package nsqd

import (
    "bufio"
    "compress/flate"
    "crypto/tls"
    "fmt"
    "net"
    "sync"
    "sync/atomic"
    "time"

    "github.com/mreiferson/go-snappystream"
    "github.com/nsqio/nsq/internal/auth"
)

// defaultBufferSize is the size (16KiB) used for each client's buffered
// reader and as the initial size of its buffered writer.
const defaultBufferSize = 16 << 10

// Connection states stored in clientV2.State (values 0 through 4, in the
// order listed).
const (
    // stateInit is the state newClientV2 assigns to a fresh client
    stateInit = iota
    stateDisconnected
    stateConnected
    stateSubscribed
    // stateClosing is the state StartClose assigns
    stateClosing
)

// identifyDataV2 is the JSON-tagged set of client supplied settings that
// Identify validates and applies to a clientV2.
type identifyDataV2 struct {
    ShortID string `json:"short_id"` // TODO: deprecated, remove in 1.0
    LongID  string `json:"long_id"`  // TODO: deprecated, remove in 1.0

    // ClientID and Hostname fall back to ShortID and LongID respectively
    // when empty (see Identify)
    ClientID            string `json:"client_id"`
    Hostname            string `json:"hostname"`
    // milliseconds; -1 disables heartbeats, 0 keeps the current value
    // (see SetHeartbeatInterval)
    HeartbeatInterval   int    `json:"heartbeat_interval"`
    // -1 requests an effectively unbuffered writer, 0 keeps the current
    // value (see SetOutputBufferSize)
    OutputBufferSize    int    `json:"output_buffer_size"`
    // milliseconds; -1 sets the timeout to zero, 0 keeps the current value
    // (see SetOutputBufferTimeout)
    OutputBufferTimeout int    `json:"output_buffer_timeout"`
    // NOTE(review): FeatureNegotiation, TLSv1, Deflate, DeflateLevel and
    // Snappy are not read by Identify in this file; presumably they are
    // consumed by the protocol handler - confirm against the caller
    FeatureNegotiation  bool   `json:"feature_negotiation"`
    TLSv1               bool   `json:"tls_v1"`
    Deflate             bool   `json:"deflate"`
    DeflateLevel        int    `json:"deflate_level"`
    Snappy              bool   `json:"snappy"`
    // accepted range is 0-99 (see SetSampleRate)
    SampleRate          int32  `json:"sample_rate"`
    UserAgent           string `json:"user_agent"`
    // milliseconds; 0 keeps the current value (see SetMsgTimeout)
    MsgTimeout          int    `json:"msg_timeout"`
}

// identifyEvent is the snapshot of client settings that Identify sends on
// clientV2.IdentifyEventChan once the requested values have been applied.
type identifyEvent struct {
    OutputBufferTimeout time.Duration
    HeartbeatInterval   time.Duration
    SampleRate          int32
    MsgTimeout          time.Duration
}

// clientV2 holds the per-connection state of a single client: its counters,
// negotiated settings, the (possibly upgraded) reader/writer pair and its
// authorization state.
type clientV2 struct {
    // 64bit atomic vars need to be first for proper alignment on 32bit platforms
    ReadyCount    int64
    InFlightCount int64
    MessageCount  uint64
    FinishCount   uint64
    RequeueCount  uint64

    // writeLock is taken by SetHeartbeatInterval, SetOutputBufferSize,
    // SetOutputBufferTimeout, SetMsgTimeout and the Upgrade* methods;
    // metaLock is taken around ClientID, Hostname and UserAgent by
    // Identify and Stats
    writeLock sync.RWMutex
    metaLock  sync.RWMutex

    ID        int64
    ctx       *context
    UserAgent string

    // original connection
    net.Conn

    // connections based on negotiated features
    tlsConn     *tls.Conn
    flateWriter *flate.Writer

    // reading/writing interfaces
    Reader *bufio.Reader
    Writer *bufio.Writer

    OutputBufferSize    int
    OutputBufferTimeout time.Duration

    HeartbeatInterval time.Duration

    MsgTimeout time.Duration

    // State holds one of the state* constants; StartClose writes it atomically
    State          int32
    ConnectTime    time.Time
    Channel        *Channel
    // ReadyStateChan receives a non-blocking signal from tryUpdateReadyState
    ReadyStateChan chan int
    ExitChan       chan int

    ClientID string
    Hostname string

    // SampleRate is written and read with atomic operations
    // (see SetSampleRate and Stats)
    SampleRate int32

    IdentifyEventChan chan identifyEvent
    SubEventChan      chan *Channel

    // TLS, Snappy and Deflate are atomically set to 1 by the corresponding
    // Upgrade* method
    TLS     int32
    Snappy  int32
    Deflate int32

    // re-usable buffer for reading the 4-byte lengths off the wire
    // (lenSlice is a slice over lenBuf, set up in newClientV2)
    lenBuf   [4]byte
    lenSlice []byte

    AuthSecret string
    AuthState  *auth.State
}

func newClientV2(id int64, conn net.Conn, ctx *context) *clientV2 {
    var identifier string
    if conn != nil {
        identifier, _, _ = net.SplitHostPort(conn.RemoteAddr().String())
    }

    c := &clientV2{
        ID:  id,
        ctx: ctx,

        Conn: conn,

        Reader: bufio.NewReaderSize(conn, defaultBufferSize),
        Writer: bufio.NewWriterSize(conn, defaultBufferSize),

        OutputBufferSize:    defaultBufferSize,
        OutputBufferTimeout: 250 * time.Millisecond,

        MsgTimeout: ctx.nsqd.getOpts().MsgTimeout,

        // ReadyStateChan has a buffer of 1 to guarantee that in the event
        // there is a race the state update is not lost
        ReadyStateChan: make(chan int, 1),
        ExitChan:       make(chan int),
        ConnectTime:    time.Now(),
        State:          stateInit,

        ClientID: identifier,
        Hostname: identifier,

        SubEventChan:      make(chan *Channel, 1),
        IdentifyEventChan: make(chan identifyEvent, 1),

        // heartbeats are client configurable but default to 30s
        HeartbeatInterval: ctx.nsqd.getOpts().ClientTimeout / 2,
    }
    c.lenSlice = c.lenBuf[:]
    return c
}

func (c *clientV2) String() string {
    return c.RemoteAddr().String()
}

func (c *clientV2) Identify(data identifyDataV2) error {
    c.ctx.nsqd.logf("[%s] IDENTIFY: %+v", c, data)

    // TODO: for backwards compatibility, remove in 1.0
    hostname := data.Hostname
    if hostname == "" {
        hostname = data.LongID
    }
    // TODO: for backwards compatibility, remove in 1.0
    clientID := data.ClientID
    if clientID == "" {
        clientID = data.ShortID
    }

    c.metaLock.Lock()
    c.ClientID = clientID
    c.Hostname = hostname
    c.UserAgent = data.UserAgent
    c.metaLock.Unlock()

    err := c.SetHeartbeatInterval(data.HeartbeatInterval)
    if err != nil {
        return err
    }

    err = c.SetOutputBufferSize(data.OutputBufferSize)
    if err != nil {
        return err
    }

    err = c.SetOutputBufferTimeout(data.OutputBufferTimeout)
    if err != nil {
        return err
    }

    err = c.SetSampleRate(data.SampleRate)
    if err != nil {
        return err
    }

    err = c.SetMsgTimeout(data.MsgTimeout)
    if err != nil {
        return err
    }

    ie := identifyEvent{
        OutputBufferTimeout: c.OutputBufferTimeout,
        HeartbeatInterval:   c.HeartbeatInterval,
        SampleRate:          c.SampleRate,
        MsgTimeout:          c.MsgTimeout,
    }

    // update the client's message pump
    select {
    case c.IdentifyEventChan <- ie:
    default:
    }

    return nil
}

func (c *clientV2) Stats() ClientStats {
    c.metaLock.RLock()
    // TODO: deprecated, remove in 1.0
    name := c.ClientID

    clientID := c.ClientID
    hostname := c.Hostname
    userAgent := c.UserAgent
    var identity string
    var identityURL string
    if c.AuthState != nil {
        identity = c.AuthState.Identity
        identityURL = c.AuthState.IdentityURL
    }
    c.metaLock.RUnlock()
    stats := ClientStats{
        // TODO: deprecated, remove in 1.0
        Name: name,

        Version:         "V2",
        RemoteAddress:   c.RemoteAddr().String(),
        ClientID:        clientID,
        Hostname:        hostname,
        UserAgent:       userAgent,
        State:           atomic.LoadInt32(&c.State),
        ReadyCount:      atomic.LoadInt64(&c.ReadyCount),
        InFlightCount:   atomic.LoadInt64(&c.InFlightCount),
        MessageCount:    atomic.LoadUint64(&c.MessageCount),
        FinishCount:     atomic.LoadUint64(&c.FinishCount),
        RequeueCount:    atomic.LoadUint64(&c.RequeueCount),
        ConnectTime:     c.ConnectTime.Unix(),
        SampleRate:      atomic.LoadInt32(&c.SampleRate),
        TLS:             atomic.LoadInt32(&c.TLS) == 1,
        Deflate:         atomic.LoadInt32(&c.Deflate) == 1,
        Snappy:          atomic.LoadInt32(&c.Snappy) == 1,
        Authed:          c.HasAuthorizations(),
        AuthIdentity:    identity,
        AuthIdentityURL: identityURL,
    }
    if stats.TLS {
        p := prettyConnectionState{c.tlsConn.ConnectionState()}
        stats.CipherSuite = p.GetCipherSuite()
        stats.TLSVersion = p.GetVersion()
        stats.TLSNegotiatedProtocol = p.NegotiatedProtocol
        stats.TLSNegotiatedProtocolIsMutual = p.NegotiatedProtocolIsMutual
    }
    return stats
}

// prettyConnectionState wraps a tls.ConnectionState so that its numeric
// cipher suite and protocol version can be rendered as readable strings.
type prettyConnectionState struct {
    tls.ConnectionState
}

// GetCipherSuite returns the name of the connection's cipher suite, or
// "Unknown <id>" when the id is not one of the suites listed below.
func (p *prettyConnectionState) GetCipherSuite() string {
    known := [...]struct {
        id   uint16
        name string
    }{
        {tls.TLS_RSA_WITH_RC4_128_SHA, "TLS_RSA_WITH_RC4_128_SHA"},
        {tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA, "TLS_RSA_WITH_3DES_EDE_CBC_SHA"},
        {tls.TLS_RSA_WITH_AES_128_CBC_SHA, "TLS_RSA_WITH_AES_128_CBC_SHA"},
        {tls.TLS_RSA_WITH_AES_256_CBC_SHA, "TLS_RSA_WITH_AES_256_CBC_SHA"},
        {tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, "TLS_ECDHE_ECDSA_WITH_RC4_128_SHA"},
        {tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA"},
        {tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA"},
        {tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA, "TLS_ECDHE_RSA_WITH_RC4_128_SHA"},
        {tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, "TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA"},
        {tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA"},
        {tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA"},
        {tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"},
        {tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"},
    }

    for i := range known {
        if known[i].id == p.CipherSuite {
            return known[i].name
        }
    }
    return fmt.Sprintf("Unknown %d", p.CipherSuite)
}

// GetVersion returns a short label for the connection's protocol version
// ("SSL30", "TLS1.0", "TLS1.1" or "TLS1.2"), or "Unknown <version>" for any
// other value.
func (p *prettyConnectionState) GetVersion() string {
    var label string

    switch version := p.Version; version {
    case tls.VersionTLS12:
        label = "TLS1.2"
    case tls.VersionTLS11:
        label = "TLS1.1"
    case tls.VersionTLS10:
        label = "TLS1.0"
    case tls.VersionSSL30:
        label = "SSL30"
    default:
        label = fmt.Sprintf("Unknown %d", version)
    }

    return label
}

func (c *clientV2) IsReadyForMessages() bool {
    if c.Channel.IsPaused() {
        return false
    }

    readyCount := atomic.LoadInt64(&c.ReadyCount)
    inFlightCount := atomic.LoadInt64(&c.InFlightCount)

    if c.ctx.nsqd.getOpts().Verbose {
        c.ctx.nsqd.logf("[%s] state rdy: %4d inflt: %4d",
            c, readyCount, inFlightCount)
    }

    if inFlightCount >= readyCount || readyCount <= 0 {
        return false
    }

    return true
}

// SetReadyCount atomically stores count as the client's ReadyCount and then
// signals ReadyStateChan (without blocking) that the ready state may have
// changed.
func (c *clientV2) SetReadyCount(count int64) {
    atomic.StoreInt64(&c.ReadyCount, count)
    c.tryUpdateReadyState()
}

// tryUpdateReadyState performs a non-blocking send on ReadyStateChan; when
// the send cannot proceed immediately the signal is simply dropped.
func (c *clientV2) tryUpdateReadyState() {
    // you can always *try* to write to ReadyStateChan because in the cases
    // where you cannot the message pump loop would have iterated anyway.
    // the atomic integer operations guarantee correctness of the value.
    select {
    case c.ReadyStateChan <- 1:
    default:
    }
}

// FinishedMessage atomically increments FinishCount, decrements
// InFlightCount, and signals a possible ready state change.
func (c *clientV2) FinishedMessage() {
    atomic.AddUint64(&c.FinishCount, 1)
    atomic.AddInt64(&c.InFlightCount, -1)
    c.tryUpdateReadyState()
}

// Empty atomically resets InFlightCount to zero and signals a possible ready
// state change.
func (c *clientV2) Empty() {
    atomic.StoreInt64(&c.InFlightCount, 0)
    c.tryUpdateReadyState()
}

// SendingMessage atomically increments both InFlightCount and MessageCount.
// Unlike the other counter methods it does not signal ReadyStateChan.
func (c *clientV2) SendingMessage() {
    atomic.AddInt64(&c.InFlightCount, 1)
    atomic.AddUint64(&c.MessageCount, 1)
}

// TimedOutMessage atomically decrements InFlightCount and signals a possible
// ready state change.
func (c *clientV2) TimedOutMessage() {
    atomic.AddInt64(&c.InFlightCount, -1)
    c.tryUpdateReadyState()
}

// RequeuedMessage atomically increments RequeueCount, decrements
// InFlightCount, and signals a possible ready state change.
func (c *clientV2) RequeuedMessage() {
    atomic.AddUint64(&c.RequeueCount, 1)
    atomic.AddInt64(&c.InFlightCount, -1)
    c.tryUpdateReadyState()
}

// StartClose sets the client's ready count to zero (so that
// IsReadyForMessages reports false) and then atomically moves the client
// into stateClosing.
func (c *clientV2) StartClose() {
    // Force the client into ready 0
    c.SetReadyCount(0)
    // mark this client as closing
    atomic.StoreInt32(&c.State, stateClosing)
}

// Pause only signals a possible ready state change; it stores no paused flag
// on the client. The paused condition itself is read from the channel in
// IsReadyForMessages.
func (c *clientV2) Pause() {
    c.tryUpdateReadyState()
}

// UnPause only signals a possible ready state change; it stores no paused
// flag on the client. The paused condition itself is read from the channel
// in IsReadyForMessages.
func (c *clientV2) UnPause() {
    c.tryUpdateReadyState()
}

func (c *clientV2) SetHeartbeatInterval(desiredInterval int) error {
    c.writeLock.Lock()
    defer c.writeLock.Unlock()

    switch {
    case desiredInterval == -1:
        c.HeartbeatInterval = 0
    case desiredInterval == 0:
        // do nothing (use default)
    case desiredInterval >= 1000 &&
        desiredInterval <= int(c.ctx.nsqd.getOpts().MaxHeartbeatInterval/time.Millisecond):
        c.HeartbeatInterval = time.Duration(desiredInterval) * time.Millisecond
    default:
        return fmt.Errorf("heartbeat interval (%d) is invalid", desiredInterval)
    }

    return nil
}

func (c *clientV2) SetOutputBufferSize(desiredSize int) error {
    var size int

    switch {
    case desiredSize == -1:
        // effectively no buffer (every write will go directly to the wrapped net.Conn)
        size = 1
    case desiredSize == 0:
        // do nothing (use default)
    case desiredSize >= 64 && desiredSize <= int(c.ctx.nsqd.getOpts().MaxOutputBufferSize):
        size = desiredSize
    default:
        return fmt.Errorf("output buffer size (%d) is invalid", desiredSize)
    }

    if size > 0 {
        c.writeLock.Lock()
        defer c.writeLock.Unlock()
        c.OutputBufferSize = size
        err := c.Writer.Flush()
        if err != nil {
            return err
        }
        c.Writer = bufio.NewWriterSize(c.Conn, size)
    }

    return nil
}

func (c *clientV2) SetOutputBufferTimeout(desiredTimeout int) error {
    c.writeLock.Lock()
    defer c.writeLock.Unlock()

    switch {
    case desiredTimeout == -1:
        c.OutputBufferTimeout = 0
    case desiredTimeout == 0:
        // do nothing (use default)
    case desiredTimeout >= 1 &&
        desiredTimeout <= int(c.ctx.nsqd.getOpts().MaxOutputBufferTimeout/time.Millisecond):
        c.OutputBufferTimeout = time.Duration(desiredTimeout) * time.Millisecond
    default:
        return fmt.Errorf("output buffer timeout (%d) is invalid", desiredTimeout)
    }

    return nil
}

func (c *clientV2) SetSampleRate(sampleRate int32) error {
    if sampleRate < 0 || sampleRate > 99 {
        return fmt.Errorf("sample rate (%d) is invalid", sampleRate)
    }
    atomic.StoreInt32(&c.SampleRate, sampleRate)
    return nil
}

func (c *clientV2) SetMsgTimeout(msgTimeout int) error {
    c.writeLock.Lock()
    defer c.writeLock.Unlock()

    switch {
    case msgTimeout == 0:
        // do nothing (use default)
    case msgTimeout >= 1000 &&
        msgTimeout <= int(c.ctx.nsqd.getOpts().MaxMsgTimeout/time.Millisecond):
        c.MsgTimeout = time.Duration(msgTimeout) * time.Millisecond
    default:
        return fmt.Errorf("msg timeout (%d) is invalid", msgTimeout)
    }

    return nil
}

func (c *clientV2) UpgradeTLS() error {
    c.writeLock.Lock()
    defer c.writeLock.Unlock()

    tlsConn := tls.Server(c.Conn, c.ctx.nsqd.tlsConfig)
    tlsConn.SetDeadline(time.Now().Add(5 * time.Second))
    err := tlsConn.Handshake()
    if err != nil {
        return err
    }
    c.tlsConn = tlsConn

    c.Reader = bufio.NewReaderSize(c.tlsConn, defaultBufferSize)
    c.Writer = bufio.NewWriterSize(c.tlsConn, c.OutputBufferSize)

    atomic.StoreInt32(&c.TLS, 1)

    return nil
}

func (c *clientV2) UpgradeDeflate(level int) error {
    c.writeLock.Lock()
    defer c.writeLock.Unlock()

    conn := c.Conn
    if c.tlsConn != nil {
        conn = c.tlsConn
    }

    c.Reader = bufio.NewReaderSize(flate.NewReader(conn), defaultBufferSize)

    fw, _ := flate.NewWriter(conn, level)
    c.flateWriter = fw
    c.Writer = bufio.NewWriterSize(fw, c.OutputBufferSize)

    atomic.StoreInt32(&c.Deflate, 1)

    return nil
}

func (c *clientV2) UpgradeSnappy() error {
    c.writeLock.Lock()
    defer c.writeLock.Unlock()

    conn := c.Conn
    if c.tlsConn != nil {
        conn = c.tlsConn
    }

    c.Reader = bufio.NewReaderSize(snappystream.NewReader(conn, snappystream.SkipVerifyChecksum), defaultBufferSize)
    c.Writer = bufio.NewWriterSize(snappystream.NewWriter(conn), c.OutputBufferSize)

    atomic.StoreInt32(&c.Snappy, 1)

    return nil
}

func (c *clientV2) Flush() error {
    var zeroTime time.Time
    if c.HeartbeatInterval > 0 {
        c.SetWriteDeadline(time.Now().Add(c.HeartbeatInterval))
    } else {
        c.SetWriteDeadline(zeroTime)
    }

    err := c.Writer.Flush()
    if err != nil {
        return err
    }

    if c.flateWriter != nil {
        return c.flateWriter.Flush()
    }

    return nil
}

func (c *clientV2) QueryAuthd() error {
    remoteIP, _, err := net.SplitHostPort(c.String())
    if err != nil {
        return err
    }

    tls := atomic.LoadInt32(&c.TLS) == 1
    tlsEnabled := "false"
    if tls {
        tlsEnabled = "true"
    }

    authState, err := auth.QueryAnyAuthd(c.ctx.nsqd.getOpts().AuthHTTPAddresses,
        remoteIP, tlsEnabled, c.AuthSecret, c.ctx.nsqd.getOpts().HTTPClientConnectTimeout,
        c.ctx.nsqd.getOpts().HTTPClientRequestTimeout)
    if err != nil {
        return err
    }
    c.AuthState = authState
    return nil
}

// Auth records secret as the client's AuthSecret and then calls QueryAuthd,
// returning its result. The secret stays recorded even when the query fails.
func (c *clientV2) Auth(secret string) error {
    c.AuthSecret = secret
    return c.QueryAuthd()
}

func (c *clientV2) IsAuthorized(topic, channel string) (bool, error) {
    if c.AuthState == nil {
        return false, nil
    }
    if c.AuthState.IsExpired() {
        err := c.QueryAuthd()
        if err != nil {
            return false, err
        }
    }
    if c.AuthState.IsAllowed(topic, channel) {
        return true, nil
    }
    return false, nil
}

func (c *clientV2) HasAuthorizations() bool {
    if c.AuthState != nil {
        return len(c.AuthState.Authorizations) != 0
    }
    return false
}