X3

RedSky

导航

C# SQL 帮助类（可扩展）

查看代码
/// <summary>
/// Marks a class or struct as a database table and carries table-level options
/// for the DbContext helpers.
/// </summary>
[AttributeUsage(AttributeTargets.Class | AttributeTargets.Struct, AllowMultiple = false, Inherited = false)]
public class DbTableAttribute : System.Attribute
{
    /// <summary>Table name; when left empty, ValidateTableAttribute substitutes the type name.</summary>
    public string Name { get; set; }

    /// <summary>Default character set for CREATE TABLE (MySqlDbContext falls back to utf8mb4).</summary>
    public string Charset { get; set; }

    /// <summary>Default collation for CREATE TABLE (MySqlDbContext falls back to utf8mb4_unicode_ci).</summary>
    public string Collate { get; set; }
}
/// <summary>
/// Maps a public instance property to a table column. Properties without this attribute
/// are ignored by the DbContext helpers.
/// </summary>
[System.AttributeUsage(AttributeTargets.Property, Inherited = false, AllowMultiple = false)]
public class DbColumnAttribute : Attribute
{
    /// <summary>Column name used in generated SQL and when mapping result columns back to the property.</summary>
    public string Name { get; set; }
    /// <summary>
    /// Column definition written verbatim after the column name in CREATE TABLE and
    /// ALTER TABLE ... ADD statements, so it must come from trusted code, never user input.
    /// <para>Shape: type[(length)] [primary key | unique] [unsigned] [zerofill] [not null | null] [default value] [comment 'text'] [collate 'collation'] ...</para>
    /// <para>Examples (MySQL):</para>
    /// <para>bigint primary key auto_increment</para>
    /// <para>int default '0'</para>
    /// <para>varchar(50) null default null collate 'utf8_general_ci'</para>
    /// <para>datetime null</para>
    /// <para>datetime null default current_timestamp</para>
    /// <para>timestamp not null default current_timestamp on update current_timestamp</para>
    /// <para>bit(1) null default b'0' comment 'your comment'</para>
    /// </summary>
    public string Desc { get; set; }
    /// <summary>
    /// Comma-separated names of the indexes this column belongs to, e.g. "Index1,Index2,Index3".
    /// Columns listing the same name are combined into one multi-column index.
    /// The names "primary key" and "unique" are only markers: the constraint itself has to be
    /// declared in <see cref="Desc"/>; MySqlDbContext does not create an index for them.
    /// </summary>
    public string Index { get; set; }
    /// <summary>
    /// When true, Insert leaves this column out, e.g. for auto-increment keys or columns
    /// that should receive their database default.
    /// </summary>
    public bool NotInsert { get; set; }
}
/// <summary>
/// Provider-agnostic base for a small reflection-based SQL helper. Subclasses bind the ADO.NET
/// provider types and implement the schema-inspection queries
/// (<see cref="ExistTable(string)"/>, <see cref="ExistColumn(string, string)"/>, <see cref="ExistIndex(string, string, string)"/>).
/// </summary>
/// <remarks>
/// A single connection is opened lazily by <see cref="Connect(bool)"/> and reused for every call.
/// A transaction started by <see cref="BeginTransaction"/> is attached to every command until
/// <see cref="EndTransaction"/>. Table, column and index names and <see cref="DbColumnAttribute.Desc"/>
/// are spliced into SQL text verbatim and must come from trusted code; only values are parameterized.
/// No locking is performed, so an instance should not be shared across threads without external synchronization.
/// </remarks>
public abstract class DbContext<TTransaction, TConnection, TCommand, TParameter, TDataAdapter, TDataReader> : IDisposable
    where TTransaction : DbTransaction
    where TConnection : DbConnection
    where TCommand : DbCommand
    where TParameter : DbParameter
    where TDataAdapter : DbDataAdapter
    where TDataReader : DbDataReader
{
    /// <summary>
    /// Index names in <see cref="DbColumnAttribute.Index"/> that denote key constraints declared in
    /// <see cref="DbColumnAttribute.Desc"/> rather than ordinary indexes.
    /// </summary>
    private static readonly HashSet<string> ConstraintIndexNames = new HashSet<string>(StringComparer.OrdinalIgnoreCase)
    {
        "primary", "primary key", "primarykey", "unique", "unique index"
    };

    /// <summary>Connection string (or provider-specific data source) used when (re)connecting.</summary>
    protected string connectStr;
    /// <summary>Transaction opened by <see cref="BeginTransaction"/>; <c>null</c> when none is active.</summary>
    protected TTransaction transaction = null;
    /// <summary>The single connection shared by every command issued through this context.</summary>
    protected TConnection connection = null;

    /// <summary>Gets or sets the underlying connection (opened lazily by <see cref="Connect(bool)"/>).</summary>
    public TConnection Connection
    {
        get { return connection; }
        set { connection = value; }
    }

    /// <summary>Executes a statement that returns no rows.</summary>
    /// <returns>The provider-reported number of affected rows.</returns>
    public virtual int ExecuteNonQuery(string sql, params TParameter[] parameters)
    {
        using (TCommand cmd = PrepareCommand(sql, parameters))
        {
            return cmd.ExecuteNonQuery();
        }
    }

    /// <summary>Executes a query and returns the first column of the first row (or <c>null</c>).</summary>
    public virtual object ExecuteScalar(string sql, params TParameter[] parameters)
    {
        using (TCommand cmd = PrepareCommand(sql, parameters))
        {
            return cmd.ExecuteScalar();
        }
    }

    /// <summary>
    /// Executes <paramref name="sql"/> and returns either the rows of the first result set as a
    /// <see cref="DataTable"/> (when it has columns and at least one row) or, otherwise, the
    /// boxed <see cref="DbDataReader.RecordsAffected"/> count.
    /// </summary>
    public virtual object ExecuteReader(string sql, params TParameter[] parameters)
    {
        using (TCommand cmd = PrepareCommand(sql, parameters))
        using (DbDataReader reader = cmd.ExecuteReader())
        {
            // Inspect the FIRST result set without consuming any row; DataTable.Load reads from the current position.
            if (reader.FieldCount > 0 && reader.HasRows)
            {
                var dt = new DataTable();
                dt.Load(reader);
                return dt;
            }
            // No rows to return: let the provider process the rest of the batch before reading the count.
            while (reader.NextResult()) { }
            return reader.RecordsAffected;
        }
    }

    /// <summary>Runs a query through <typeparamref name="TDataAdapter"/> and returns its first table, or <c>null</c> if none.</summary>
    public virtual DataTable Query(string sql, params TParameter[] parameters)
    {
        using (TCommand cmd = PrepareCommand(sql, parameters))
        using (var adapter = (TDataAdapter)Activator.CreateInstance(typeof(TDataAdapter), cmd))
        {
            var ds = new DataSet();
            adapter.Fill(ds, "ds");
            return ds.Tables.Count > 0 ? ds.Tables[0] : null;
        }
    }

    /// <summary>
    /// Creates a command on <see cref="Connection"/> with the given text and parameters.
    /// The connection must already exist; the Execute*/Query methods call <see cref="Connect(bool)"/> first.
    /// </summary>
    public virtual TCommand CreateCommand(string sql, params TParameter[] parameters)
    {
        TCommand cmd = (TCommand)Connection.CreateCommand();
        cmd.CommandText = sql;
        if (parameters != null && parameters.Length > 0)
            cmd.Parameters.AddRange(parameters);
        return cmd;
    }

    /// <summary>
    /// Connects if needed, builds the command via <see cref="CreateCommand"/> and enlists it in the
    /// active transaction, so every Execute*/Query call participates in that transaction.
    /// </summary>
    private TCommand PrepareCommand(string sql, TParameter[] parameters)
    {
        Connect();
        TCommand cmd = CreateCommand(sql, parameters);
        if (transaction != null && cmd.Transaction == null)
            cmd.Transaction = transaction;
        return cmd;
    }

    /// <summary>
    /// Creates a provider parameter named <paramref name="name"/>; a <c>null</c> value is sent as
    /// <see cref="DBNull.Value"/> so providers treat it as SQL NULL rather than "value not supplied".
    /// </summary>
    protected virtual TParameter CreateParameter(string name, object value)
    {
        var parameter = (TParameter)Activator.CreateInstance(typeof(TParameter));
        parameter.ParameterName = name;
        parameter.Value = value ?? DBNull.Value;
        return parameter;
    }

    /// <summary>Adds the column described by <paramref name="dbColumnAttribute"/> to the table mapped by <typeparamref name="Table"/>.</summary>
    public virtual int AddColumn<Table>(DbColumnAttribute dbColumnAttribute)
    {
        if (dbColumnAttribute == null)
            dbColumnAttribute = new DbColumnAttribute();
        // Use the mapped table name ([DbTable(Name=...)]), not the CLR type name.
        string tableName = GetDbTableAttribute<Table>().Name;
        return AddColumn(tableName, dbColumnAttribute.Name, dbColumnAttribute.Desc);
    }

    /// <summary>Runs <c>alter table `tableName` add `columnName` options</c>.</summary>
    public virtual int AddColumn(string tableName, string columnName, string options)
    {
        // Identifiers are back-quoted; single quotes would make the table name a string literal in MySQL.
        string sql = $"alter table `{tableName}` add `{columnName}` {options}";
        return ExecuteNonQuery(sql);
    }

    /// <summary>Creates index <paramref name="index"/> on <paramref name="tableName"/> over <paramref name="columns"/>, in order.</summary>
    public virtual void AddIndex(string tableName, string index, string[] columns)
    {
        // Quote each column separately: `a`,`b` rather than `a,b`.
        string sql = $"create index `{index}` on `{tableName}` (`{string.Join("`,`", columns)}`)";
        ExecuteNonQuery(sql);
    }

    /// <summary>Drops an index using MySQL syntax (<c>drop index ... on table</c>).</summary>
    public virtual void DeleteIndex(string tableName, string index)
    {
        ExecuteNonQuery($"drop index `{index}` on `{tableName}`");
    }

    /// <summary>Runs a query and maps every row onto a new <typeparamref name="Table"/> via <see cref="TableToList{Table}"/>.</summary>
    public virtual List<Table> GetList<Table>(string sql, params TParameter[] parameters)
    {
        DataTable dt = Query(sql, parameters);
        return TableToList<Table>(dt);
    }

    /// <summary>
    /// Ensures an open connection. An existing connection is reused unless <paramref name="reconnect"/>
    /// is true or it is Closed/Broken; otherwise it is disposed and a new one is opened from <see cref="connectStr"/>.
    /// NOTE(review): an active transaction stays bound to the old connection; commands then fail until EndTransaction.
    /// </summary>
    public virtual void Connect(bool reconnect = false)
    {
        if (!reconnect && connection != null
            && connection.State != System.Data.ConnectionState.Closed
            && connection.State != System.Data.ConnectionState.Broken)
            return;
        connection?.Dispose();
        connection = (TConnection)Activator.CreateInstance(typeof(TConnection), this.connectStr);
        connection.Open();
    }

    /// <summary>Creates the table mapped by <typeparamref name="Table"/> with one definition per [DbColumn] property.</summary>
    public virtual int CreateTable<Table>()
    {
        Type tableType = typeof(Table);
        var tableAttribute = GetDbTableAttribute<Table>();
        var definitions = new List<string>();
        foreach (PropertyInfo propertyInfo in tableType.GetProperties(BindingFlags.Instance | BindingFlags.Public))
        {
            var columnInfo = propertyInfo.GetCustomAttribute<DbColumnAttribute>();
            if (columnInfo == null) continue;
            definitions.Add($"\t`{columnInfo.Name}` {columnInfo.Desc}");
        }
        var sb = new StringBuilder($"create table `{tableAttribute.Name}`");
        if (definitions.Count > 0)
            sb.Append($" (\r\n{string.Join(",\r\n", definitions)}\r\n)\r\n");
        return ExecuteNonQuery(sb.ToString());
    }

    /// <summary>Returns whether <paramref name="column"/> exists in <paramref name="table"/>.</summary>
    public abstract bool ExistColumn(string column, string table);
    /// <summary>Returns whether <paramref name="index"/> exists (and, when given, covers <paramref name="columnName"/>).</summary>
    public abstract bool ExistIndex(string tableName, string index, string columnName = null);
    /// <summary>Returns whether <paramref name="table"/> exists.</summary>
    public abstract bool ExistTable(string table);

    /// <summary>
    /// Groups [DbColumn] names by the index names listed in <see cref="DbColumnAttribute.Index"/>.
    /// Index names are trimmed and empty entries are skipped; column order follows property order.
    /// </summary>
    public virtual Dictionary<string, List<string>> GetIndexs(Type table)
    {
        var indexDict = new Dictionary<string, List<string>>();
        foreach (var propertyInfo in table.GetProperties(BindingFlags.Instance | BindingFlags.Public))
        {
            var column = propertyInfo.GetCustomAttribute<DbColumnAttribute>();
            if (column == null || string.IsNullOrWhiteSpace(column.Index)) continue;
            foreach (string rawName in column.Index.Split(','))
            {
                string index = rawName.Trim();
                if (index.Length == 0) continue;
                List<string> columns;
                if (!indexDict.TryGetValue(index, out columns))
                {
                    columns = new List<string>();
                    indexDict.Add(index, columns);
                }
                columns.Add(column.Name);
            }
        }
        return indexDict;
    }

    /// <summary>
    /// Returns true for index names ("primary key", "unique", ...) that only mark a key constraint
    /// declared in <see cref="DbColumnAttribute.Desc"/>, so no separate index should be created for them.
    /// </summary>
    protected virtual bool IsConstraintIndex(string index)
    {
        return index != null && ConstraintIndexNames.Contains(index.Trim());
    }

    /// <summary>Stores a new connection string and forces a reconnect.</summary>
    public virtual void UseServer(string connectStr)
    {
        this.connectStr = connectStr;
        Connect(true);
    }

    /// <summary>Brings the schema of <typeparamref name="Table"/> up to date: table, then columns, then indexes.</summary>
    public virtual void Repair<Table>()
    {
        var tableType = typeof(Table);
        RepairTable<Table>();
        RepairColumns<Table>();
        RepairIndex(tableType);
    }

    /// <summary>Creates the mapped table when it does not exist yet.</summary>
    public virtual void RepairTable<Table>()
    {
        var table = GetDbTableAttribute<Table>();
        var tableType = typeof(Table);
        ValidateTableAttribute(table, tableType);
        if (!ExistTable(table.Name))
            CreateTable<Table>();
    }

    /// <summary>Adds every [DbColumn] that the live table is missing.</summary>
    public virtual void RepairColumns<Table>()
    {
        var table = GetDbTableAttribute<Table>();
        foreach (var propertyInfo in typeof(Table).GetProperties(BindingFlags.Instance | BindingFlags.Public))
        {
            var column = propertyInfo.GetCustomAttribute<DbColumnAttribute>();
            if (column == null) continue;
            if (!ExistColumn(column.Name, table.Name))
                AddColumn(table.Name, column.Name, column.Desc);
        }
    }

    /// <summary>
    /// (Re)creates each declared index that is missing or does not cover all of its columns.
    /// Constraint markers (see <see cref="IsConstraintIndex"/>) are skipped.
    /// </summary>
    public virtual void RepairIndex(Type table)
    {
        // Index statements need the mapped table name, not the CLR type name.
        string tableName = NameOf(table);
        foreach (KeyValuePair<string, List<string>> entry in GetIndexs(table))
        {
            if (IsConstraintIndex(entry.Key)) continue;
            bool complete = entry.Value.All(column => ExistIndex(tableName, entry.Key, column));
            if (complete) continue;
            if (ExistIndex(tableName, entry.Key))
                DeleteIndex(tableName, entry.Key);
            AddIndex(tableName, entry.Key, entry.Value.ToArray());
        }
    }

    /// <summary>Disposes the transaction and connection; safe to call more than once.</summary>
    public virtual void Dispose()
    {
        transaction?.Dispose();
        transaction = null;
        connection?.Dispose();
        connection = null;
    }

    /// <summary>
    /// Inserts <paramref name="model"/> into its mapped table. Every [DbColumn] without
    /// <see cref="DbColumnAttribute.NotInsert"/> is sent as a parameter named <c>@column</c>.
    /// </summary>
    /// <returns>The number of affected rows.</returns>
    public virtual int Insert<Table>(Table model)
    {
        var columns = new List<string>();
        var parameters = new List<TParameter>();
        foreach (PropertyInfo propertyInfo in typeof(Table).GetProperties(BindingFlags.Instance | BindingFlags.Public))
        {
            var columnInfo = propertyInfo.GetCustomAttribute<DbColumnAttribute>();
            if (columnInfo == null || columnInfo.NotInsert) continue;
            columns.Add(columnInfo.Name);
            parameters.Add(CreateParameter($"@{columnInfo.Name}", propertyInfo.GetValue(model)));
        }
        string tableName = GetDbTableAttribute<Table>().Name;
        string columnList = string.Join(",", columns.Select(c => $"`{c}`"));
        string valueList = string.Join(",", columns.Select(c => $"@{c}"));
        string sql = $"insert into `{tableName}` ({columnList}) values ({valueList})";
        return ExecuteNonQuery(sql, parameters.ToArray());
    }

    /// <summary>
    /// Deletes the row whose primary-key column equals <paramref name="value"/>. The key column is the
    /// first [DbColumn] whose Desc contains "primary key" or whose Index lists a primary marker.
    /// </summary>
    /// <returns>The number of deleted rows, or 0 when no key column is declared.</returns>
    public virtual int DeleteByPrimaryKey<Table>(object value)
    {
        string tableName = GetDbTableAttribute<Table>().Name;
        foreach (var propertyInfo in typeof(Table).GetProperties(BindingFlags.Instance | BindingFlags.Public))
        {
            var columnInfo = propertyInfo.GetCustomAttribute<DbColumnAttribute>();
            if (columnInfo == null || !IsPrimaryKey(columnInfo)) continue;
            return Delete(tableName, columnInfo.Name, value);
        }
        return 0;
    }

    /// <summary>Null-safe, case-insensitive check for a primary-key declaration on a column.</summary>
    private static bool IsPrimaryKey(DbColumnAttribute column)
    {
        if (column.Desc != null && column.Desc.IndexOf("primary key", StringComparison.OrdinalIgnoreCase) >= 0)
            return true;
        if (string.IsNullOrWhiteSpace(column.Index))
            return false;
        foreach (string rawName in column.Index.Split(','))
        {
            string name = rawName.Trim();
            if (name.Equals("primary", StringComparison.OrdinalIgnoreCase)
                || name.Equals("primary key", StringComparison.OrdinalIgnoreCase)
                || name.Equals("primarykey", StringComparison.OrdinalIgnoreCase))
                return true;
        }
        return false;
    }

    /// <summary>Deletes rows where <paramref name="column"/> equals <paramref name="value"/> (sent as a parameter).</summary>
    public virtual int Delete(string table, string column, object value)
    {
        return ExecuteNonQuery($"delete from `{table}` where `{column}`=@{column}", CreateParameter($"@{column}", value));
    }

    /// <summary>Deletes rows matching a caller-supplied WHERE clause; <paramref name="where"/> is inserted verbatim.</summary>
    public virtual int Delete(string table, string where, params TParameter[] parameters)
    {
        return ExecuteNonQuery($"delete from `{table}` where {where}", parameters);
    }

    /// <summary>Starts a transaction on the shared connection; subsequent commands are enlisted in it.</summary>
    public virtual void BeginTransaction()
    {
        Connect();
        transaction = (TTransaction)connection.BeginTransaction();
    }

    /// <summary>
    /// Commits the active transaction. If the commit throws, the transaction is rolled back and
    /// <c>false</c> is returned. The transaction is always disposed and cleared.
    /// </summary>
    /// <exception cref="InvalidOperationException">No transaction is active.</exception>
    public virtual bool EndTransaction()
    {
        if (transaction == null)
            throw new InvalidOperationException("No transaction is active; call BeginTransaction first.");
        try
        {
            transaction.Commit();
            return true;
        }
        catch
        {
            // A failing Rollback propagates its exception, but the finally block still releases the transaction.
            transaction.Rollback();
            return false;
        }
        finally
        {
            transaction.Dispose();
            transaction = null;
        }
    }

    public DbContext() { }
    /// <summary>
    /// Invokes <paramref name="callback"/> with the new instance (e.g. to call UseServer).
    /// It runs from the base constructor, before any derived constructor body.
    /// </summary>
    public DbContext(Action<DbContext<TTransaction, TConnection, TCommand, TParameter, TDataAdapter, TDataReader>> callback) { callback?.Invoke(this); }

    /// <summary>
    /// Maps each row of <paramref name="dt"/> onto a new <typeparamref name="Table"/>. Result columns are
    /// matched to writable [DbColumn] properties by name, case-insensitively. NULL cells and
    /// unconvertible values leave the property at its default.
    /// </summary>
    public virtual List<Table> TableToList<Table>(DataTable dt)
    {
        var list = new List<Table>();
        if (dt == null || dt.Rows.Count <= 0) return list;
        var dict = new Dictionary<string, PropertyInfo>(StringComparer.OrdinalIgnoreCase);
        foreach (var propertyInfo in typeof(Table).GetProperties(BindingFlags.Instance | BindingFlags.Public))
        {
            var column = propertyInfo.GetCustomAttribute<DbColumnAttribute>();
            if (column == null || string.IsNullOrEmpty(column.Name) || !propertyInfo.CanWrite) continue;
            if (!dict.ContainsKey(column.Name))
                dict.Add(column.Name, propertyInfo);
        }
        foreach (DataRow row in dt.Rows)
        {
            // Box once so property writes also stick when Table is a struct.
            object model = Activator.CreateInstance<Table>();
            foreach (DataColumn column in dt.Columns)
            {
                PropertyInfo propertyInfo;
                if (!dict.TryGetValue(column.ColumnName, out propertyInfo)) continue;
                object value = row[column];
                if (value == null || value is DBNull) continue;
                object converted;
                if (TryConvertValue(value, propertyInfo.PropertyType, out converted))
                    propertyInfo.SetValue(model, converted);
            }
            list.Add((Table)model);
        }
        return list;
    }

    /// <summary>
    /// Converts a non-null database value to <paramref name="propertyType"/>. It handles strings,
    /// enums (parsed from the value's text), Nullable&lt;T&gt;, directly assignable values, and
    /// IConvertible primitives such as bool, numbers and DateTime.
    /// </summary>
    /// <returns><c>false</c> when the target type is not supported; the property is then left untouched.</returns>
    protected virtual bool TryConvertValue(object value, Type propertyType, out object result)
    {
        Type target = Nullable.GetUnderlyingType(propertyType) ?? propertyType;
        if (target == typeof(string))
            result = value.ToString();
        else if (target.IsEnum)
            result = Enum.Parse(target, value.ToString());
        else if (target.IsInstanceOfType(value))
            result = value;
        else if (value is IConvertible && Type.GetTypeCode(target) != TypeCode.Object)
            result = Convert.ChangeType(value, target);
        else
        {
            result = null;
            return false;
        }
        return true;
    }

    /// <summary>Returns the validated [DbTable] settings of <typeparamref name="Table"/>.</summary>
    public DbTableAttribute GetDbTableAttribute<Table>() => GetDbTableAttribute(typeof(Table));
    /// <summary>Returns the [DbTable] attribute of <paramref name="t"/> (or a default one), after <see cref="ValidateTableAttribute"/>.</summary>
    public virtual DbTableAttribute GetDbTableAttribute(Type t)
    {
        var tableAttribute = t.GetCustomAttribute<DbTableAttribute>();
        if (tableAttribute == null)
            tableAttribute = new DbTableAttribute();
        ValidateTableAttribute(tableAttribute, t);
        return tableAttribute;
    }

    /// <summary>Fills missing settings; the base version defaults the table name to the type name.</summary>
    public virtual void ValidateTableAttribute(DbTableAttribute tableAttribute, Type type)
    {
        if (string.IsNullOrWhiteSpace(tableAttribute.Name))
            tableAttribute.Name = type.Name;
    }

    /// <summary>Returns a column-name to property-value map for every [DbColumn] property of <paramref name="model"/>.</summary>
    public virtual Dictionary<string, object> GetColumns<Table>(Table model)
    {
        var dict = new Dictionary<string, object>();
        foreach (PropertyInfo propertyInfo in typeof(Table).GetProperties(BindingFlags.Instance | BindingFlags.Public))
        {
            var columnInfo = propertyInfo.GetCustomAttribute<DbColumnAttribute>();
            if (columnInfo == null) continue;
            dict.Add(columnInfo.Name, propertyInfo.GetValue(model));
        }
        return dict;
    }

    /// <summary>Returns the [DbColumn] attributes of <typeparamref name="Table"/> in property order.</summary>
    public virtual List<DbColumnAttribute> GetColumnAttributes<Table>() => GetColumnAttributes(typeof(Table));
    /// <summary>Returns the [DbColumn] attributes of <paramref name="tableType"/> in property order.</summary>
    public virtual List<DbColumnAttribute> GetColumnAttributes(Type tableType)
    {
        var dbColumnAttributes = new List<DbColumnAttribute>();
        foreach (PropertyInfo propertyInfo in tableType.GetProperties(BindingFlags.Instance | BindingFlags.Public))
        {
            var columnInfo = propertyInfo.GetCustomAttribute<DbColumnAttribute>();
            if (columnInfo == null) continue;
            dbColumnAttributes.Add(columnInfo);
        }
        return dbColumnAttributes;
    }

    /// <summary>Returns the mapped table name of <typeparamref name="Table"/>.</summary>
    public virtual string NameOf<Table>() => NameOf(typeof(Table));
    /// <summary>Returns the mapped table name of <paramref name="type"/>.</summary>
    public virtual string NameOf(Type type)
    {
        return GetDbTableAttribute(type).Name;
    }
    /// <summary>Returns the column name mapped to a property of <typeparamref name="Table"/>.</summary>
    /// <typeparam name="Table">The entity type.</typeparam>
    /// <param name="propertyName">Property name, e.g. <c>nameof(Table.Property)</c>.</param>
    /// <returns>The [DbColumn] name, or <paramref name="propertyName"/> itself when the property is not mapped.</returns>
    public virtual string NameOf<Table>(string propertyName) => NameOf(typeof(Table), propertyName);
    /// <summary>Returns the column name mapped to a property of <paramref name="tableType"/>.</summary>
    /// <param name="tableType">The entity type.</param>
    /// <param name="propertyName">Property name, e.g. <c>nameof(Table.Property)</c>.</param>
    /// <returns>The [DbColumn] name, or <paramref name="propertyName"/> itself when the property is not mapped.</returns>
    public virtual string NameOf(Type tableType, string propertyName)
    {
        foreach (PropertyInfo propertyInfo in tableType.GetProperties(BindingFlags.Instance | BindingFlags.Public))
        {
            var columnInfo = propertyInfo.GetCustomAttribute<DbColumnAttribute>();
            if (columnInfo != null && propertyInfo.Name == propertyName) return columnInfo.Name;
        }
        return propertyName;
    }
}

MySQL:

查看代码
/// <summary>
/// MySQL binding of the DbContext helper. Table creation includes ordinary indexes,
/// the default character set and the collation. Schema checks query information_schema
/// for the current database (<c>database()</c>).
/// </summary>
public class MySqlDbContext : DbContext<MySqlTransaction, MySqlConnection, MySqlCommand, MySqlParameter, MySqlDataAdapter, MySqlDataReader>
{
    public MySqlDbContext() { }
    public MySqlDbContext(Action<DbContext<MySqlTransaction, MySqlConnection, MySqlCommand, MySqlParameter, MySqlDataAdapter, MySqlDataReader>> callback) : base(callback) { }

    /// <summary>
    /// Creates the table when it is missing (indexes included, see <see cref="CreateTable{Table}"/>).
    /// Otherwise it adds missing columns and repairs indexes.
    /// </summary>
    public override void Repair<Table>()
    {
        var table = GetDbTableAttribute<Table>();
        var tableType = typeof(Table);
        ValidateTableAttribute(table, tableType);
        if (!ExistTable(table.Name))
        {
            CreateTable<Table>();
            return;
        }
        RepairColumns<Table>();
        RepairIndex(tableType);
    }

    /// <summary>
    /// Builds and runs a single CREATE TABLE with all [DbColumn] definitions and the ordinary indexes
    /// from [DbColumn].Index. Key-constraint markers are skipped because the constraint is declared in Desc.
    /// </summary>
    public override int CreateTable<Table>()
    {
        Type tableType = typeof(Table);
        var tableAttribute = GetDbTableAttribute<Table>();
        var definitions = new List<string>();
        foreach (PropertyInfo propertyInfo in tableType.GetProperties(BindingFlags.Instance | BindingFlags.Public))
        {
            var columnInfo = propertyInfo.GetCustomAttribute<DbColumnAttribute>();
            if (columnInfo == null) continue;
            definitions.Add($"\t`{columnInfo.Name}` {columnInfo.Desc}");
        }
        foreach (KeyValuePair<string, List<string>> entry in GetIndexs(tableType))
        {
            if (IsKeyConstraintName(entry.Key)) continue;
            definitions.Add($"\tindex `{entry.Key}` (`{string.Join("`,`", entry.Value)}`)");
        }
        // Joining avoids trailing-comma surgery (which crashed on an empty definition list).
        var sb = new StringBuilder($"create table `{tableAttribute.Name}`");
        sb.Append($" (\r\n{string.Join(",\r\n", definitions)}\r\n)\r\n");
        if (!string.IsNullOrWhiteSpace(tableAttribute.Charset))
            sb.AppendLine($"default character set {tableAttribute.Charset}");
        if (!string.IsNullOrWhiteSpace(tableAttribute.Collate))
            sb.AppendLine($"collate {tableAttribute.Collate}");
        return ExecuteNonQuery(sb.ToString());
    }

    /// <summary>Defaults the name to the type name, the charset to utf8mb4 and the collation to utf8mb4_unicode_ci.</summary>
    public override void ValidateTableAttribute(DbTableAttribute tableAttribute, Type type)
    {
        if (string.IsNullOrWhiteSpace(tableAttribute.Name))
            tableAttribute.Name = type.Name;
        if (string.IsNullOrWhiteSpace(tableAttribute.Charset))
            tableAttribute.Charset = "utf8mb4";
        if (string.IsNullOrWhiteSpace(tableAttribute.Collate))
            tableAttribute.Collate = "utf8mb4_unicode_ci";
    }

    /// <summary>Checks information_schema.columns of the current database; names are sent as parameters.</summary>
    public override bool ExistColumn(string column, string table)
    {
        object r = ExecuteScalar(
            "select 1 from information_schema.columns where table_schema = database() and table_name = @table and column_name = @column",
            new MySqlParameter("@table", table),
            new MySqlParameter("@column", column));
        return IsPositive(r);
    }

    /// <summary>
    /// Checks information_schema.statistics for <paramref name="index"/> on <paramref name="tableName"/>.
    /// When <paramref name="columnName"/> is given, it also checks that the index covers that column.
    /// </summary>
    public override bool ExistIndex(string tableName, string index, string columnName = null)
    {
        string sql = "select count(*) from information_schema.statistics where table_schema = database() and table_name = @table and index_name = @index";
        var parameters = new List<MySqlParameter>
        {
            new MySqlParameter("@table", tableName),
            new MySqlParameter("@index", index)
        };
        if (!string.IsNullOrWhiteSpace(columnName))
        {
            sql += " and column_name = @column";
            parameters.Add(new MySqlParameter("@column", columnName));
        }
        // count(*) comes back as a 64-bit integer, so an "is int" test would never match.
        return IsPositive(ExecuteScalar(sql, parameters.ToArray()));
    }

    /// <summary>Checks information_schema.tables of the current database.</summary>
    public override bool ExistTable(string table)
    {
        object r = ExecuteScalar(
            "select 1 from information_schema.tables where table_schema = database() and table_name = @table",
            new MySqlParameter("@table", table));
        return IsPositive(r);
    }

    /// <summary>True for [DbColumn].Index names that only mark a PRIMARY KEY / UNIQUE constraint declared in Desc.</summary>
    private static bool IsKeyConstraintName(string name)
    {
        switch (name.Trim().ToLowerInvariant())
        {
            case "primary":
            case "primary key":
            case "primarykey":
            case "unique":
            case "unique index":
                return true;
            default:
                return false;
        }
    }

    /// <summary>Interprets a scalar result (int, long, decimal, ...) as "at least one match".</summary>
    private static bool IsPositive(object scalar)
    {
        return scalar != null && !(scalar is DBNull) && Convert.ToInt64(scalar) > 0;
    }
}

SQLite:

查看代码
/// <summary>
/// SQLite binding of the DbContext helper. <c>connectStr</c> may be a bare database file path or a
/// full connection string containing "data source=". Schema checks use sqlite_master and PRAGMAs.
/// </summary>
public class SqliteDbContext : DbContext<SQLiteTransaction, SQLiteConnection, SQLiteCommand, SQLiteParameter, SQLiteDataAdapter, SQLiteDataReader>
{
    public SqliteDbContext() { }
    public SqliteDbContext(Action<DbContext<SQLiteTransaction, SQLiteConnection, SQLiteCommand, SQLiteParameter, SQLiteDataAdapter, SQLiteDataReader>> callback) : base(callback) { }

    /// <summary>
    /// Opens the database, creating its directory and file first when they are missing.
    /// A live connection is reused unless <paramref name="reconnect"/> is true.
    /// </summary>
    /// <exception cref="InvalidOperationException">No data source has been configured.</exception>
    public override void Connect(bool reconnect = false)
    {
        if (!reconnect && Connection != null && (Connection.State == ConnectionState.Open || Connection.State == ConnectionState.Connecting)) return;
        if (string.IsNullOrWhiteSpace(connectStr))
            throw new InvalidOperationException("No SQLite data source is configured; call UseServer first.");
        bool isConnectionString = connectStr.IndexOf("data source=", StringComparison.OrdinalIgnoreCase) >= 0;
        EnsureDatabaseFile(isConnectionString ? GetDataSource(connectStr) : connectStr);
        Connection?.Dispose();
        Connection = new SQLiteConnection(isConnectionString ? connectStr : "data source=" + connectStr);
        Connection.Open();
    }

    /// <summary>Reads the "data source" value out of a connection string, or null when absent.</summary>
    private static string GetDataSource(string connectionString)
    {
        var builder = new System.Data.Common.DbConnectionStringBuilder { ConnectionString = connectionString };
        object value;
        return builder.TryGetValue("data source", out value) ? Convert.ToString(value) : null;
    }

    /// <summary>Creates the parent directory and an empty database file for an on-disk data source.</summary>
    private static void EnsureDatabaseFile(string file)
    {
        // In-memory databases, URI data sources and |DataDirectory| placeholders have no plain path to prepare.
        if (string.IsNullOrWhiteSpace(file)
            || file.Trim().Equals(":memory:", StringComparison.OrdinalIgnoreCase)
            || file.StartsWith("file:", StringComparison.OrdinalIgnoreCase)
            || file.IndexOf('|') >= 0)
            return;
        // GetFullPath so a bare relative filename yields its real directory instead of "".
        string directory = Path.GetDirectoryName(Path.GetFullPath(file));
        if (!string.IsNullOrEmpty(directory) && !Directory.Exists(directory))
            Directory.CreateDirectory(directory);
        if (!File.Exists(file))
            SQLiteConnection.CreateFile(file);
    }

    /// <summary>Checks the table's real column list (PRAGMA table_info) instead of pattern-matching its CREATE text.</summary>
    public override bool ExistColumn(string column, string table)
    {
        return ContainsName(Query($"pragma table_info({QuoteIdentifier(table)})"), column);
    }

    /// <summary>
    /// Checks sqlite_master for an index named <paramref name="index"/>. When <paramref name="columnName"/>
    /// is given, it also checks PRAGMA index_info to confirm the index covers that column.
    /// </summary>
    public override bool ExistIndex(string tableName, string index, string columnName = null)
    {
        // NOTE(review): tableName is not used as a filter (as before); SQLite index names are schema-wide.
        object r = ExecuteScalar(
            "select count(*) from sqlite_master where type='index' and name=@index collate nocase",
            new SQLiteParameter("@index", index));
        // SQLite returns integers as long, so compare numerically rather than with "is int".
        if (r == null || r is DBNull || Convert.ToInt64(r) == 0)
            return false;
        if (string.IsNullOrWhiteSpace(columnName))
            return true;
        return ContainsName(Query($"pragma index_info({QuoteIdentifier(index)})"), columnName);
    }

    /// <summary>Checks sqlite_master for a table named <paramref name="table"/> (case-insensitive, like SQLite identifiers).</summary>
    public override bool ExistTable(string table)
    {
        var r = ExecuteScalar(
            "select 1 from sqlite_master where type='table' and name=@table collate nocase",
            new SQLiteParameter("@table", table));
        return r != null && r.ToString() == "1";
    }

    /// <summary>Drops an index; SQLite's DROP INDEX takes only the index name (no "on table" clause).</summary>
    public override void DeleteIndex(string tableName, string index)
    {
        ExecuteNonQuery($"drop index if exists {QuoteIdentifier(index)}");
    }

    /// <summary>Double-quotes an identifier for PRAGMA/DDL text, doubling embedded quotes.</summary>
    private static string QuoteIdentifier(string identifier)
    {
        return "\"" + identifier.Replace("\"", "\"\"") + "\"";
    }

    /// <summary>True when a PRAGMA result has a row whose "name" column equals <paramref name="name"/> (case-insensitive).</summary>
    private static bool ContainsName(DataTable rows, string name)
    {
        if (rows == null || !rows.Columns.Contains("name")) return false;
        foreach (DataRow row in rows.Rows)
        {
            if (string.Equals(Convert.ToString(row["name"]), name, StringComparison.OrdinalIgnoreCase))
                return true;
        }
        return false;
    }
}

使用方法:

[DbTable(Name = "account")]
public class Account
{
    // Auto-increment key generated by the database, so Insert must leave it out.
    [DbColumn(Name = "id", Desc = "bigint primary key auto_increment", NotInsert = true)]
    public long Id { get; set; }
    // Unique login name; also listed in the ordinary index "un".
    [DbColumn(Name = "un", Desc = "varchar(50) not null unique", Index = "un")]
    public string Username { get; set; }
    // Demo only: real code should store a password hash (hence the wider column), never the plain text.
    [DbColumn(Name = "pwd", Desc = "varchar(128) not null")]
    public string Password { get; set; }
    // Quoted default must be closed ('0'); the unbalanced quote made CREATE TABLE fail.
    [DbColumn(Name = "is_deleted", Desc = "int(1) default '0'")]
    public bool Deleted { get; set; }
    // Filled by the database default: inserting default(DateTime) (year 1) is outside the TIMESTAMP range.
    [DbColumn(Name = "create_time", Desc = "timestamp default localtime", NotInsert = true)]
    public DateTime CreateTime { get; set; }
}
static void Main(string[] args)
{
    // Dispose the context so the MySQL connection is closed when the demo finishes.
    using (MySqlDbContext dbContext = new MySqlDbContext())
    {
        dbContext.UseServer("Data Source=127.0.0.1; Database=tempdb; User ID=admin; Password=123;Charset=utf8mb4;");
        // Create the table, or add any missing columns/indexes to an existing one.
        dbContext.Repair<Account>();
        // `un` and `pwd` are NOT NULL, so they must be filled; CreateTime is set in case create_time is not NotInsert.
        var account = new Account { Username = "demo", Password = "123456", CreateTime = DateTime.Now };
        dbContext.Insert(account);
        // Insert does not read back the generated key; LAST_INSERT_ID() is per connection and the context reuses one.
        account.Id = Convert.ToInt64(dbContext.ExecuteScalar("select last_insert_id()"));
        var list = dbContext.GetList<Account>("select * from account limit 10");
        dbContext.DeleteByPrimaryKey<Account>(account.Id);
    }
    Console.WriteLine("按任意键退出。");
    Console.ReadKey();
}

 

posted on 2024-08-13 15:33  HotSky  阅读(6)  评论(0)  编辑  收藏  举报