使用 .NET Core / EF Core 根据数据库结构自动生成实体类

源码

GitHub（已更新为最新代码）：

https://github.com/leoparddne/GenEntities/

使用的数据库是 MySQL，所以先通过 NuGet 安装 MySql.Data。

创建 T4 模板文件，内容如下：

<#@ assembly name="System.Core"#>
<#@ assembly name="System.Data.Linq"#>
<#@ assembly name="EnvDTE"#>
<#@ assembly name="System.Xml"#>
<#@ assembly name="System.Xml.Linq"#>
<#@ import namespace="System"#>
<#@ import namespace="System.CodeDom"#>
<#@ import namespace="System.CodeDom.Compiler"#>
<#@ import namespace="System.Collections.Generic"#>
<#@ import namespace="System.Data.Linq"#>
<#@ import namespace="System.Data.Linq.Mapping"#>
<#@ import namespace="System.IO"#>
<#@ import namespace="System.Linq"#>
<#@ import namespace="System.Reflection"#>
<#@ import namespace="System.Text"#>
<#@ import namespace="System.Xml.Linq"#>
<#@ import namespace="Microsoft.VisualStudio.TextTemplating"#>
<#+

// Manager splits the text produced by a T4 template into several output files.
// Text emitted after StartNewFile(name) up to EndBlock() (or the next Start* call)
// becomes the file `name`, written next to the template; text emitted inside the
// header/footer blocks is prepended/appended to every such file. Call
// Process(true) once at the end of the template to write the files.
class Manager {
    // A span of the template's output buffer: [Start, Start + Length).
    private class Block {
        public String Name;
        public int Start, Length;
    }

    private Block currentBlock;
    private List<Block> files = new List<Block>();
    private Block footer = new Block();
    private Block header = new Block();
    private ITextTemplatingEngineHost host;
    private StringBuilder template;
    protected List<String> generatedFileNames = new List<String>();

    // Returns a Visual Studio-aware manager when the host exposes IServiceProvider,
    // otherwise a manager that only writes files to disk.
    public static Manager Create(ITextTemplatingEngineHost host, StringBuilder template) {
        return (host is IServiceProvider) ? new VSManager(host, template) : new Manager(host, template);
    }

    // Starts collecting output for the file `name` (combined with the template's
    // folder when written), closing any block that is currently open.
    public void StartNewFile(String name) {
        if (name == null)
            throw new ArgumentNullException("name");
        CurrentBlock = new Block { Name = name };
    }

    // Starts collecting the footer that is appended to every generated file.
    public void StartFooter() {
        CurrentBlock = footer;
    }

    // Starts collecting the header that is prepended to every generated file.
    public void StartHeader() {
        CurrentBlock = header;
    }

    // Closes the open block and records its length; file blocks are queued for Process.
    public void EndBlock() {
        if (CurrentBlock == null)
            return;
        CurrentBlock.Length = template.Length - CurrentBlock.Start;
        if (CurrentBlock != header && CurrentBlock != footer)
            files.Add(CurrentBlock);
        // Assign the field, not the property: the setter would call EndBlock again.
        currentBlock = null;
    }

    // When split is true, writes every file block (wrapped in header/footer text)
    // to disk and removes the block's text from `template`. When split is false,
    // nothing is written.
    public virtual void Process(bool split) {
        if (split) {
            EndBlock();
            String headerText = template.ToString(header.Start, header.Length);
            String footerText = template.ToString(footer.Start, footer.Length);
            String outputPath = Path.GetDirectoryName(host.TemplateFile);
            // Blocks were queued in buffer order; removing from the end first keeps
            // the Start offsets of the remaining blocks valid.
            files.Reverse();
            foreach(Block block in files) {
                String fileName = Path.Combine(outputPath, block.Name);
                String content = headerText + template.ToString(block.Start, block.Length) + footerText;
                generatedFileNames.Add(fileName);
                CreateFile(fileName, content);
                template.Remove(block.Start, block.Length);
            }
        }
    }

    // Writes the file only when its content changed, so unchanged files keep their timestamps.
    protected virtual void CreateFile(String fileName, String content) {
        if (IsFileContentDifferent(fileName, content))
            File.WriteAllText(fileName, content);
    }

    public virtual String GetCustomToolNamespace(String fileName) {
        return null;
    }

    public virtual String DefaultProjectNamespace {
        get { return null; }
    }

    // True when the file does not exist or its current text differs from newContent.
    protected bool IsFileContentDifferent(String fileName, String newContent) {
        return !(File.Exists(fileName) && File.ReadAllText(fileName) == newContent);
    }

    private Manager(ITextTemplatingEngineHost host, StringBuilder template) {
        this.host = host;
        this.template = template;
    }

    // Switching blocks closes the previous one and anchors the new one at the
    // current end of the output buffer.
    private Block CurrentBlock {
        get { return currentBlock; }
        set {
            if (CurrentBlock != null)
                EndBlock();
            if (value != null)
                value.Start = template.Length;
            currentBlock = value;
        }
    }

    // Manager used inside Visual Studio: checks files out of source control before
    // overwriting them and keeps the generated files as project items under the template.
    private class VSManager: Manager {
        private EnvDTE.ProjectItem templateProjectItem;
        private EnvDTE.DTE dte;
        private Action<String> checkOutAction;
        private Action<IEnumerable<String>> projectSyncAction;

        public override String DefaultProjectNamespace {
            get {
                // The template may not belong to any project of the open solution.
                if (templateProjectItem == null)
                    return base.DefaultProjectNamespace;
                return templateProjectItem.ContainingProject.Properties.Item("DefaultNamespace").Value.ToString();
            }
        }

        public override String GetCustomToolNamespace(string fileName) {
            var item = dte.Solution.FindProjectItem(fileName);
            if (item == null)
                return base.GetCustomToolNamespace(fileName);
            return item.Properties.Item("CustomToolNamespace").Value.ToString();
        }

        public override void Process(bool split) {
            if (templateProjectItem == null) {
                // Template is not part of an open project: write the files, skip project sync.
                base.Process(split);
                return;
            }
            if (templateProjectItem.ProjectItems == null)
                return;
            base.Process(split);
            // Runs the sync through the delegate's BeginInvoke and blocks until it completes
            // (delegate BeginInvoke is a .NET Framework feature; the VS T4 host runs on it).
            projectSyncAction.EndInvoke(projectSyncAction.BeginInvoke(generatedFileNames, null, null));
        }

        protected override void CreateFile(String fileName, String content) {
            if (IsFileContentDifferent(fileName, content)) {
                CheckoutFileIfRequired(fileName);
                File.WriteAllText(fileName, content);
            }
        }

        internal VSManager(ITextTemplatingEngineHost host, StringBuilder template)
            : base(host, template) {
            // `as` instead of a cast so that the null check below can actually fire.
            var hostServiceProvider = host as IServiceProvider;
            if (hostServiceProvider == null)
                throw new ArgumentNullException("host", "Could not obtain IServiceProvider");
            dte = (EnvDTE.DTE) hostServiceProvider.GetService(typeof(EnvDTE.DTE));
            if (dte == null)
                throw new ArgumentNullException("host", "Could not obtain DTE from host");
            templateProjectItem = dte.Solution.FindProjectItem(host.TemplateFile);
            checkOutAction = (String fileName) => dte.SourceControl.CheckOutItem(fileName);
            projectSyncAction = (IEnumerable<String> keepFileNames) => ProjectSync(templateProjectItem, keepFileNames);
        }

        // Makes the template's nested project items match keepFileNames. Items that are
        // neither kept nor named like the template's own output ("<TemplateName>.*") are
        // deleted from the project. Kept files that are missing are added.
        private static void ProjectSync(EnvDTE.ProjectItem templateProjectItem, IEnumerable<String> keepFileNames) {
            // Windows paths are case-insensitive; compare them that way so a casing
            // difference between DTE's paths and ours does not delete a kept file.
            var keepFileNameSet = new HashSet<String>(keepFileNames, StringComparer.OrdinalIgnoreCase);
            var projectFiles = new Dictionary<String, EnvDTE.ProjectItem>(StringComparer.OrdinalIgnoreCase);
            var originalFilePrefix = Path.GetFileNameWithoutExtension(templateProjectItem.get_FileNames(0)) + ".";
            foreach(EnvDTE.ProjectItem projectItem in templateProjectItem.ProjectItems)
                projectFiles.Add(projectItem.get_FileNames(0), projectItem);

            // Remove unused items from the project
            foreach(var pair in projectFiles)
                if (!keepFileNameSet.Contains(pair.Key)
                    && !(Path.GetFileNameWithoutExtension(pair.Key) + ".").StartsWith(originalFilePrefix, StringComparison.OrdinalIgnoreCase))
                    pair.Value.Delete();

            // Add missing files to the project
            foreach(String fileName in keepFileNameSet)
                if (!projectFiles.ContainsKey(fileName))
                    templateProjectItem.ProjectItems.AddFromFile(fileName);
        }

        // Checks the file out only if it is under source control and not already checked out.
        private void CheckoutFileIfRequired(String fileName) {
            var sc = dte.SourceControl;
            if (sc != null && sc.IsItemUnderSCC(fileName) && !sc.IsItemCheckedOut(fileName))
                checkOutAction.EndInvoke(checkOutAction.BeginInvoke(fileName, null, null));
        }
    }
} #>

将该文件的后缀名改为 .ttinclude，并命名为 Manager.ttinclude。

 

同上创建EntityHelper.tt,将后缀名改为ttinclude,注意第一行的<em id="__mceDel"></em>

<em id="__mceDel"></em>

<#@ assembly name="System.Core"#>
<#@ assembly name="System.Data"#>
<#@ assembly name="MySql.Data" #>
<#@ import namespace="System" #>
<#@ import namespace="System.Data" #>
<#@ import namespace="System.Collections.Generic" #>
<#@ import namespace="System.Linq" #>
<#@ import namespace="MySql.Data.MySqlClient" #>
<#+
    /// <summary>
    /// Reads MySQL table/column metadata and maps column types to C# type names.
    /// </summary>
    public class EntityHelper
    {
        /// <summary>
        /// Reads every column of every table in the given MySQL schemas from
        /// information_schema.COLUMNS and groups them into one <see cref="Entity"/> per table.
        /// </summary>
        /// <param name="connectionString">MySQL connection string.</param>
        /// <param name="databases">Schema names to read. An empty list yields an empty result.</param>
        /// <returns>
        /// Entities ordered by schema and table name. Each entity's fields are in column
        /// declaration order (ORDINAL_POSITION). Tables with the same name in different
        /// schemas become separate entities that share the same EntityName.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="databases"/> is null.</exception>
        public static List<Entity> GetEntities(string connectionString, List<string> databases)
        {
            if (databases == null)
                throw new ArgumentNullException("databases");

            var list = new List<Entity>();
            if (databases.Count == 0)
                return list;

            // One named parameter per schema, so names are bound rather than spliced into the SQL text.
            var parameters = new MySqlParameter[databases.Count];
            var placeholders = new string[databases.Count];
            for (int i = 0; i < databases.Count; i++)
            {
                placeholders[i] = "@db" + i;
                parameters[i] = new MySqlParameter(placeholders[i], databases[i]);
            }

            var cmd = string.Format(@"SELECT `TABLE_SCHEMA`
                                            ,`TABLE_NAME`
                                            ,`COLUMN_NAME`
                                            ,`DATA_TYPE`
                                            ,`COLUMN_COMMENT`
                                        FROM `information_schema`.`COLUMNS`
                                        WHERE `TABLE_SCHEMA` IN ({0})
                                        ORDER BY `TABLE_SCHEMA`, `TABLE_NAME`, `ORDINAL_POSITION`", string.Join(", ", placeholders));

            // Keyed by (schema, table) so same-named tables in different schemas are not merged;
            // `list` keeps the order in which tables were first seen.
            var entitiesByTable = new Dictionary<Tuple<string, string>, Entity>();
            using (var conn = new MySqlConnection(connectionString))
            {
                conn.Open();
                using (var reader = MySqlHelper.ExecuteReader(conn, cmd, parameters))
                {
                    while (reader.Read())
                    {
                        var db = reader["TABLE_SCHEMA"].ToString();
                        var table = reader["TABLE_NAME"].ToString();
                        var key = Tuple.Create(db, table);

                        Entity entity;
                        if (!entitiesByTable.TryGetValue(key, out entity))
                        {
                            entity = new Entity(table);
                            entitiesByTable.Add(key, entity);
                            list.Add(entity);
                        }

                        entity.Fields.Add(new Field
                        {
                            Name = reader["COLUMN_NAME"].ToString(),
                            Type = GetCLRType(reader["DATA_TYPE"].ToString()),
                            Comment = reader["COLUMN_COMMENT"].ToString()
                        });
                    }
                }
            }

            return list;
        }

        /// <summary>
        /// Maps a MySQL information_schema DATA_TYPE value to a C# type name.
        /// </summary>
        /// <param name="dbType">The DATA_TYPE value, e.g. "int" or "varchar".</param>
        /// <returns>
        /// The C# type name. Types with no built-in equivalent (the spatial types and
        /// anything unrecognised) are returned unchanged so they stay visible in the generated code.
        /// </returns>
        public static string GetCLRType(string dbType)
        {
            // DATA_TYPE is normally lower-case already; normalise defensively for the lookup only.
            switch (dbType == null ? null : dbType.ToLowerInvariant())
            {
                case "tinyint":
                case "smallint":
                case "mediumint":
                case "int":
                case "integer":
                    return "int";
                case "bigint":
                    return "long";
                case "double":
                    return "double";
                case "float":
                    return "float";
                case "decimal":
                case "numeric":
                case "real":
                    return "decimal";
                case "bit":
                    return "bool";
                case "time":
                    // TIME is a duration (-838:59:59 .. 838:59:59), not a point in time.
                    return "TimeSpan";
                case "date":
                case "year":
                case "datetime":
                case "timestamp":
                    return "DateTime";
                case "tinyblob":
                case "blob":
                case "mediumblob":
                case "longblob":
                case "binary":
                case "varbinary":
                    return "byte[]";
                case "char":
                case "varchar":
                case "tinytext":
                case "text":
                case "mediumtext":
                case "longtext":
                case "json":
                case "enum":
                case "set":
                    return "string";
                default:
                    // point, linestring, polygon, geometry, multipoint, multilinestring,
                    // multipolygon, geometrycollection, ...: no standard CLR type.
                    return dbType;
            }
        }
    }
 
    /// <summary>
    /// One table read from the database schema; the template turns it into one class.
    /// </summary>
    public class Entity
    {
        /// <summary>Table name; the template uses it as the class and file name.</summary>
        public string EntityName { get; set; }

        /// <summary>The table's columns, in the order they were added.</summary>
        public List<Field> Fields { get; set; }

        /// <summary>Creates an unnamed entity with an empty field list.</summary>
        public Entity()
            : this(null)
        {
        }

        /// <summary>Creates an entity called <paramref name="name"/> with an empty field list.</summary>
        public Entity(string name)
        {
            EntityName = name;
            Fields = new List<Field>();
        }
    }
 
    /// <summary>
    /// One column of an entity: property name, C# type name and comment text.
    /// </summary>
    public class Field
    {
        private string name;
        private string type;
        private string comment;

        /// <summary>Column name; the template emits it as the property name.</summary>
        public string Name
        {
            get { return name; }
            set { name = value; }
        }

        /// <summary>C# type name of the property (for example "int" or "string").</summary>
        public string Type
        {
            get { return type; }
            set { type = value; }
        }

        /// <summary>Column comment; the template emits it in the property's summary.</summary>
        public string Comment
        {
            get { return comment; }
            set { comment = value; }
        }
    }
#>

任意创建一个 .tt 文件（内容如下），修改其中的数据库连接配置，保存后即可生成实体类。

<#@ template debug="false" hostspecific="true" language="C#" #>
<#@ include file="Manager.ttinclude" #>
<#@ include file="EntityHelper.ttinclude" #>
<#
    // 是否是WCF服务模型
    bool serviceModel = false;
     
    // 数据库连接
    var connectionString = @"Server=172.0.0.1;port=3306;database=testDB;charset=utf8;uid=test;password=test";
 
    // 需要解析的数据库
    var database = new List<string> { "testDB" };
 
    // 文件版权信息
    var copyright = DateTime.Now.Year + " xxxx Enterprises All Rights Reserved";
    var version = Environment.Version;
    var author = "auto generated by T4";
 
    var manager = Manager.Create(Host, GenerationEnvironment);
    var entities = EntityHelper.GetEntities(connectionString, database);
 
    foreach(Entity entity in entities)
    {
        manager.StartNewFile(entity.EntityName + ".cs");
#>
//-----------------------------------------------------------------------
// <copyright file=" <#= entity.EntityName #>.cs" company="xxxx Enterprises">
// * Copyright (C) <#= copyright #>
// * version : <#= version #>
// * author  : <#= author #>
// * FileName: <#= entity.EntityName #>.cs
// * history : Created by T4 <#= DateTime.Now #>
// </copyright>
//-----------------------------------------------------------------------
using System;
<#    if(serviceModel)
    {
#>
using System.Runtime.Serialization;
<#
    }
#>
 
namespace CoreData
{
    /// <summary>
    /// <#= entity.EntityName #> Entity Model
    /// </summary>   
<#    if(serviceModel)
    {
#>
    [DataContract]
<#
    }
#>
    public class <#= entity.EntityName #>
    {
<#
        for(int i = 0; i < entity.Fields.Count; i++)
        {
            if(i ==0){
#>        /// <summary>
        /// <#= entity.Fields[i].Comment #>
        /// </summary>
<#    if(serviceModel)
    {
#>
        [DataMember]
<#
    }
#>
        public <#= entity.Fields[i].Type #> <#= entity.Fields[i].Name #> { get; set; }
<#
            }
            else{
#>   
        /// <summary>
        /// <#= entity.Fields[i].Comment #>
        /// </summary>
<#    if(serviceModel)
    {
#>
        [DataMember]
<#
    }
#>
        public <#= entity.Fields[i].Type #> <#= entity.Fields[i].Name #> { get; set; }
<#            }
        }
#>
    }
}
<#       
        manager.EndBlock();
    }
 
    manager.Process(true);
#>

 

posted @ 2018-11-26 19:11  Hey,Coder!  阅读(5564)  评论(0)  编辑  收藏  举报