golang的多协程实践

go语言以优异的并发特性而闻名,刚好手上有个小项目比较适合。

项目背景:

公司播控平台的数据存储包括MySQL和ElasticSearch(ES)两个部分,编辑、运营的数据首先保存在MySQL中,为了实现模糊搜索和产品关联推荐,特别增加了ES,ES中保存的是节目集的基本信息。

本项目是为了防止实时同步数据出现问题或者系统重新初始化时的全量数据同步而做的。项目主要是从MySQL读取所有的节目集数据写入到ES中。

项目特点:

因为节目集数量较大,不能一次性的读入内存,因此每次读出一部分记录写入ES。ORM使用的是beego。为了提高性能使用了协程,其中读MySQL的部分最大开启20个协程,ES写入部分开启了15个协程(因为ES分片设置的问题,5个协程和15个协程性能映像不大)。

项目主要包括三个文件:

1、PrdES_v3.go,项目的入口,负责协调MySQL读取和ES写入。

package main

import (
    "./db"
    "./es"
    //"encoding/json"
    "fmt"
    "reflect"
    "time"
)

type PrdES struct {
    DB *prd.Mysql
    ES *es.Elastic
}

// func (this *PrdES) Handle(result []*prd.Series) {
//     // for _, value := range result {
//     //     this.DB.FormatData(value)
//     //     //json, _ := json.Marshal(value)
//     //     //fmt.Println(string(json))
//     // }
//     //写入ES,以多线程的方式执行,最多保持5个线程
//     this.ES.DoBulk(result)
// }
func (this *PrdES) Run() {
    count := 50
    offset := 0
    maxCount := 20
    //create channel
    chs := make([]chan []*prd.Series, maxCount)
    selectCase := make([]reflect.SelectCase, maxCount)
    for i := 0; i < maxCount; i++ {
        offset = count * i
        fmt.Println("offset:", offset)
        //init channel
        chs[i] = make(chan []*prd.Series)
        //set select case
        selectCase[i].Dir = reflect.SelectRecv
        selectCase[i].Chan = reflect.ValueOf(chs[i])
        //运行
        go this.DB.GetData(offset, count, chs[i])
    }
    var result []*prd.Series
    for {
        //wait data return
        chosen, recv, ok := reflect.Select(selectCase)
        if ok {
            fmt.Println("channel id:", chosen)
            result = recv.Interface().([]*prd.Series)

            //读取数据从mysql
            go this.DB.GetData(offset, count, chs[chosen])

            //写入ES,以多线程的方式执行,最多保持15个线程
            this.ES.DoBulk(result)
            //update offset
            offset = offset + len(result)
            //判断是否到达数据尾部,最后一次查询
            if len(result) < count {
                fmt.Println("read end of DB")
                //等所有的任务执行完毕
                this.ES.Over()
                fmt.Println("MySQL Total:", this.DB.GetTotal(), ",Elastic Total:", this.ES.GetTotal())
                return

            }
        }
    }

}

func main() {
    s := time.Now()
    fmt.Println("start")
    pe := new(PrdES)

    pe.DB = prd.NewDB()
    pe.ES = es.NewES()
    //fmt.Println("mysql info:")
    //fmt.Println("ES info:")
    pe.Run()

    fmt.Println("time out:", time.Since(s).Seconds(), "(s)")
    fmt.Println("Over!")

}

 在run函数里可以看到使用了reflect.SelectCase,使用reflect.SelectCase的原因是读MySQL数据是多个协程,不可预计哪个会首先返回,selectCase是任何一个处理完毕reflect.Select函数就会返回,MySQL读取的数据放在channel中宕Select函数返回时chosen, recv, ok := reflect.Select(selectCase)判断ok是否未true            chosen代表的是协程id通过result = recv.Interface().([]*prd.Series)获得返回的数据,因为MySQL读取的数据是对象的结果集,因次使用recv.Interface函数,如果是简单类型可以使用recv.recvInt(),recv.recvString()等函数直接获取channel返回数据。 

这里通过counter控制协程的数量,也可以通过channel,用select的方式控制协程的数量,之所以用counter计数器的方式控制协程数量是我想知道同时有多少协程在运行。

注:此处channel可以不用创建数组形式,channel带回来的数据也没有顺序问题。

2、es.go,负责写入ES和es的写入协程调度

package es

import (
    "../db"
    //"encoding/json"
    "fmt"
    elastigo "github.com/mattbaird/elastigo/lib"
    //elastigo "github.com/Uncodin/elastigo/lib"
    //"github.com/Uncodin/elastigo/core"
    "time"
    //"bytes"
    "flag"
    "sync"
    //"github.com/fatih/structs"
)

var (
    //开发测试库
    //host = flag.String("host", "192.168.1.236", "Elasticsearch Host")
    //C平台线上
    host = flag.String("host", "192.168.100.23", "Elasticsearch Host")
    port = flag.String("port", "9200", "Elasticsearch port")
)

//indexor := core.NewBulkIndexorErrors(10, 60)
// func init() {
//     //connect to elasticsearch
//     fmt.Println("connecting  es")
//     //api.Domain = *host //"192.168.1.236"
//     //api.Port = "9300"

// }
//save thread count
var counter int

type Elastic struct {
    //Seq int64
    c         *elastigo.Conn
    lock      *sync.Mutex
    lockTotal *sync.Mutex
    wg        *sync.WaitGroup
    total     int64
}

func (this *Elastic) Conn() {
    this.c = elastigo.NewConn()
    this.c.Domain = *host
    this.c.Port = *port
    //NewClient(fmt.Sprintf("%s:%d", *host, *port))
}
func (this *Elastic) CreateLock() {
    this.lock = &sync.Mutex{}
    this.lockTotal = &sync.Mutex{}
    this.wg = &sync.WaitGroup{}
    counter = 0
    this.total = 0
}
func NewES() (es *Elastic) {
    //connect elastic
    es = new(Elastic)
    es.Conn()
    //create lock
    es.CreateLock()
    return es
}
func (this *Elastic) DoBulk(series []*prd.Series) {
    for true {
        this.lock.Lock()
        if counter < 25 {
            //跳出,执行任务
            break
        } else {
            this.lock.Unlock()
            //等待100毫秒
            //fmt.Println("wait counter less than 25, counter:", counter)
            time.Sleep(1e8)
        }
    }
    this.lock.Unlock()
    //执行任务
    go this.bulk(series, this.lock)
}
func (this *Elastic) Over() {
    this.wg.Wait()
    /*for {
        this.lock.Lock()
        if counter <= 0 {
            this.lock.Unlock()
            break
        }
        this.lock.Unlock()
    }
    */
}

func (this *Elastic) GetTotal() (t int64) {
    this.lockTotal.Lock()
    t = this.total
    this.lockTotal.Unlock()
    return t
}
func (this *Elastic) bulk(series []*prd.Series, lock *sync.Mutex) (succCount int64) {
    //增加计数器
    this.wg.Add(1)
    //减少计数器
    defer this.wg.Done()

    //加计数器
    lock.Lock()
    counter++
    fmt.Println("add task, coutner:", counter)
    lock.Unlock()

    //设置初始成功写入的数量
    succCount = 0

    for _, value := range series {
        //json, _ := json.Marshal(value)
        //fmt.Println(string(json))
        if value.ServiceGroup != nil {
            fmt.Println("series code:", value.Code, ",ServiceGroup:", value.ServiceGroup)

            resp, err := this.c.Index("guttv", "series", value.Code, nil, *value)

            if err != nil {
                panic(err)
            } else {
                //fmt.Println(value.Code + " write to ES succsessful!")
                fmt.Println(resp)
                succCount++
            }
        } else {
            fmt.Println("series code:", value.Code, "service group is null")
        }
    }

    //计数器减一
    lock.Lock()
    counter--
    fmt.Println("reduce task, coutner:", counter, ",success count:", succCount)
    lock.Unlock()

    this.lockTotal.Lock()
    this.total = this.total + succCount
    this.lockTotal.Unlock()
    return succCount
}

 在es.go里有两把锁lock和lockTotal,前者是针对counter变量,记录es正在写入的协程数量的;后者为记录总共写入多少条记录到es而增加的。

这里必须要提到的是Over函数,初步使用协程的容易忽略。golang的原则是当main函数运行结束后,所有正在运行的协程都会终止,因袭在MySQL读取数据完毕必须调用Over函数,等待所有的协程结束。这里使用sync.waiGrooup。每次协程启动执行下面两个语句:

//增加计数器
this.wg.Add(1)
//减少计数器,函数结束时自动执行
defer this.wg.Done()
Over函数中调用
wg.Wait()等待计数器为0时返回,否则就一直阻塞。当然读者也可以看到通过检查counter是否小于等于0也可以判断协程是否都结束(Over函数被注释的部分),显然使用waitGroup更优雅和高效。

 

3、db.go,负责MySQL数据的读取

package prd

import (
    "fmt"
    "github.com/astaxie/beego/orm"
    _ "github.com/go-sql-driver/mysql" // import your used driver
    "strings"
    "sync"
    "time"
)

func init() {
  
    orm.RegisterDataBase("default", "mysql", "@tcp(192.168.100.3306)/guttv_vod?charset=utf8", 30)

    orm.RegisterModelWithPrefix("t_", new(Series), new(Product), new(ServiceGroup))
    orm.RunSyncdb("default", false, false)
}

type Mysql struct {
    sql   string
    total int64
    lock  *sync.Mutex
}

func (this *Mysql) New() {
    //this.sql = "SELECT s.*, p.code ProductCode, p.name pName  FROM guttv_vod.t_series s inner join guttv_vod.t_product p on p.itemcode=s.code  and p.isdelete=0 limit ?,?"
    this.sql = "SELECT s.*, p.code ProductCode, p.name pName  FROM guttv_vod.t_series s , guttv_vod.t_product p where p.itemcode=s.code  and p.isdelete=0 limit ?,?"
    this.total = 0
    this.lock = &sync.Mutex{}
}
func NewDB() (db *Mysql) {
    db = new(Mysql)
    db.New()
    return db
}
func (this *Mysql) GetTotal() (t int64) {
    t = 0
    this.lock.Lock()
    t = this.total
    this.lock.Unlock()
    return t
}
func (this *Mysql) toTime(toBeCharge string) int64 {
    timeLayout := "2006-01-02 15:04:05"
    loc, _ := time.LoadLocation("Local")
    theTime, _ := time.ParseInLocation(timeLayout, toBeCharge, loc)
    sr := theTime.Unix()
    if sr < 0 {
        sr = 0
    }
    return sr
}
func (this *Mysql) getSGCode(seriesCode string) (result []string, num int64) {
    sql := "select distinct ref.servicegroupcode code  from t_servicegroup_reference_category ref "
    sql = sql + "left join t_category_product cp on cp.categorycode=ref.categorycode "
    sql = sql + "left join t_package pkg on pkg.code = cp.assetcode "
    sql = sql + "left join t_package_product pp on pp.parentcode=pkg.code "
    sql = sql + "left join t_product prd on prd.code = pp.assetcode "
    sql = sql + "where   prd.itemcode=?"
    o := orm.NewOrm()
    var sg []*ServiceGroup
    num, err := o.Raw(sql, seriesCode).QueryRows(&sg)

    if err == nil {
        //fmt.Println(num)
        for _, value := range sg {
            //fmt.Println(value.Code)
            result = append(result, value.Code)
        }

    } else {
        fmt.Println(err)
    }
    //fmt.Println(result)
    return result, num
}

func (this *Mysql) formatData(value *Series) {
    //设置业务分组数据
    sg, _ := this.getSGCode(value.Code)
    //fmt.Println(sg)
    value.ServiceGroup = []string{}
    value.ServiceGroup = sg[0:]
    //更改OnlineTime为整数
    value.OnlineTimeInt = this.toTime(value.OnlineTime)
    //分解地区
    value.OriginalCountryArr = strings.Split(value.OriginalCountry, "|")
    //分解二级分类
    value.ProgramType2Arr = strings.Split(value.ProgramType2, "|")
    //写入记录内容
    value.Description = strings.Replace(value.Description, "\n", "", -1)
}
func (this *Mysql) GetData(offset int, size int, ch chan []*Series) {
    var result []*Series
    o := orm.NewOrm()
    num, err := o.Raw(this.sql, offset, size).
        QueryRows(&result)
    if err != nil {
        fmt.Println("read DB err")
        panic(err)
        //return //err, nil
    }
    for _, value := range result {
        this.formatData(value)
        //json, _ := json.Marshal(value)
        //fmt.Println(string(json))
        //fmt.Println(value.ServiceGroup)
    }
    this.lock.Lock()
    this.total += num
    this.lock.Unlock()

    fmt.Println("read count :", num) //, "Total:", Total)
    //return nil, result
    ch <- result
}

 从项目上看。go语言开发还是比较简洁的,多协程实现也相对容易,但要求开发者必须对概念非常清晰,像select和selectCase理解和defer的理解要很到位,个人层经做过多年的C/C++程序员,从经验上看,C/C++的经验(多线程的理解)对运用go语言还是很有帮助的。

 

posted @ 2016-07-08 11:40  程序员老刘  阅读(9705)  评论(2编辑  收藏  举报