package nsqd
import (
"bytes"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io"
"math"
"math/rand"
"net"
"sync/atomic"
"time"
"unsafe"
"github.com/nsqio/nsq/internal/protocol"
"github.com/nsqio/nsq/internal/version"
)
// maxTimeout is a fixed one-hour duration.
// NOTE(review): nothing in the code visible here references it — presumably an
// upper bound applied elsewhere in the package; confirm before removing.
const maxTimeout = time.Hour
// Frame type identifiers passed to Send to tag each outgoing frame.
const (
// frameTypeResponse tags command responses (and heartbeats).
frameTypeResponse int32 = 0
// frameTypeError tags error replies produced by a failed command.
frameTypeError int32 = 1
// frameTypeMessage tags message deliveries; Send does not flush these immediately.
frameTypeMessage int32 = 2
)
// separatorBytes is the delimiter used to split a command line into parameters.
var separatorBytes = []byte(" ")
// heartbeatBytes is the response payload written on every heartbeat tick.
var heartbeatBytes = []byte("_heartbeat_")
// okBytes is the response payload returned by commands that simply succeed.
var okBytes = []byte("OK")
// protocolV2 carries the handlers for the V2 client protocol: the per-connection
// IOLoop, the command dispatcher Exec and one method per command.
type protocolV2 struct {
// ctx gives the handlers access to the owning nsqd (options, logging, topics).
ctx *context
}
// IOLoop services a single client connection until it fails or is closed.
//
// It creates a clientV2 for conn, starts the messagePump goroutine and then
// reads newline-terminated commands, dispatching each one through Exec and
// writing back any response or error frame. On the way out it closes the
// connection, closes client.ExitChan (which stops the messagePump) and removes
// the client from the channel it subscribed to, if any.
//
// It returns nil when the peer closed the connection cleanly (io.EOF),
// otherwise the last error that was encountered before the loop ended.
func (p *protocolV2) IOLoop(conn net.Conn) error {
	var err error
	var line []byte
	var zeroTime time.Time

	clientID := atomic.AddInt64(&p.ctx.nsqd.clientIDSequence, 1)
	client := newClientV2(clientID, conn, p.ctx)

	// synchronize the startup of messagePump in order
	// to guarantee that it gets a chance to initialize
	// goroutine local state derived from client attributes
	// and avoid a potential race with IDENTIFY (where a client
	// could have changed or disabled said attributes)
	messagePumpStartedChan := make(chan bool)
	go p.messagePump(client, messagePumpStartedChan)
	<-messagePumpStartedChan

	for {
		// a client that misses two heartbeat intervals is considered dead;
		// with heartbeats disabled there is no read deadline at all
		if client.HeartbeatInterval > 0 {
			client.SetReadDeadline(time.Now().Add(client.HeartbeatInterval * 2))
		} else {
			client.SetReadDeadline(zeroTime)
		}

		// ReadSlice does not allocate new space for the data each request
		// ie. the returned slice is only valid until the next call to it
		line, err = client.Reader.ReadSlice('\n')
		if err != nil {
			if err == io.EOF {
				err = nil
			} else {
				err = fmt.Errorf("failed to read command - %s", err)
			}
			break
		}

		// trim the '\n'
		line = line[:len(line)-1]
		// optionally trim the '\r'
		if len(line) > 0 && line[len(line)-1] == '\r' {
			line = line[:len(line)-1]
		}
		params := bytes.Split(line, separatorBytes)

		if p.ctx.nsqd.getOpts().Verbose {
			p.ctx.nsqd.logf("PROTOCOL(V2): [%s] %s", client, params)
		}

		var response []byte
		response, err = p.Exec(client, params)
		if err != nil {
			ctx := ""
			// use the checked form of the assertion so that an error which
			// does not implement protocol.ChildErr cannot panic this
			// goroutine; such an error is simply logged without a parent
			if childErr, ok := err.(protocol.ChildErr); ok {
				if parentErr := childErr.Parent(); parentErr != nil {
					ctx = " - " + parentErr.Error()
				}
			}
			p.ctx.nsqd.logf("ERROR: [%s] - %s%s", client, err, ctx)

			sendErr := p.Send(client, frameTypeError, []byte(err.Error()))
			if sendErr != nil {
				p.ctx.nsqd.logf("ERROR: [%s] - %s%s", client, sendErr, ctx)
				break
			}

			// errors of type FatalClientErr should forceably close the connection
			if _, ok := err.(*protocol.FatalClientErr); ok {
				break
			}
			continue
		}

		if response != nil {
			err = p.Send(client, frameTypeResponse, response)
			if err != nil {
				err = fmt.Errorf("failed to send response - %s", err)
				break
			}
		}
	}

	p.ctx.nsqd.logf("PROTOCOL(V2): [%s] exiting ioloop", client)
	conn.Close()
	// closing ExitChan is what tells the messagePump goroutine to stop
	close(client.ExitChan)
	if client.Channel != nil {
		client.Channel.RemoveClient(client.ID)
	}

	return err
}
// SendMessage serializes msg into buf and writes the result to client as a
// message frame.
//
// buf is scratch space owned by the caller; it is reset before use. Any error
// from serializing or from Send is returned unchanged.
func (p *protocolV2) SendMessage(client *clientV2, msg *Message, buf *bytes.Buffer) error {
	if p.ctx.nsqd.getOpts().Verbose {
		p.ctx.nsqd.logf("PROTOCOL(V2): writing msg(%s) to client(%s) - %s",
			msg.ID, client, msg.Body)
	}

	buf.Reset()
	if _, err := msg.WriteTo(buf); err != nil {
		return err
	}

	return p.Send(client, frameTypeMessage, buf.Bytes())
}
// Send writes one framed payload of the given frameType to client while
// holding the client's write lock.
//
// The write deadline is one heartbeat interval from now, or cleared when
// heartbeats are disabled. Every frame except a message frame is flushed
// right away; message frames stay buffered for the caller to flush later.
// The first error from writing or flushing is returned.
func (p *protocolV2) Send(client *clientV2, frameType int32, data []byte) error {
	client.writeLock.Lock()

	// the zero time.Time means "no deadline"
	var deadline time.Time
	if client.HeartbeatInterval > 0 {
		deadline = time.Now().Add(client.HeartbeatInterval)
	}
	client.SetWriteDeadline(deadline)

	_, err := protocol.SendFramedResponse(client.Writer, frameType, data)
	if err == nil && frameType != frameTypeMessage {
		err = client.Flush()
	}

	client.writeLock.Unlock()
	return err
}
// Exec dispatches a parsed command line to its handler.
//
// params[0] is the command name and the remaining elements are its arguments.
// IDENTIFY is handled before the TLS policy check; every other command must
// pass enforceTLSPolicy first. An unrecognized command yields a fatal
// E_INVALID error.
func (p *protocolV2) Exec(client *clientV2, params [][]byte) ([]byte, error) {
	command := params[0]

	if string(command) == "IDENTIFY" {
		return p.IDENTIFY(client, params)
	}

	if err := enforceTLSPolicy(client, p, command); err != nil {
		return nil, err
	}

	switch string(command) {
	case "FIN":
		return p.FIN(client, params)
	case "RDY":
		return p.RDY(client, params)
	case "REQ":
		return p.REQ(client, params)
	case "PUB":
		return p.PUB(client, params)
	case "MPUB":
		return p.MPUB(client, params)
	case "DPUB":
		return p.DPUB(client, params)
	case "NOP":
		return p.NOP(client, params)
	case "TOUCH":
		return p.TOUCH(client, params)
	case "SUB":
		return p.SUB(client, params)
	case "CLS":
		return p.CLS(client, params)
	case "AUTH":
		return p.AUTH(client, params)
	}

	return nil, protocol.NewFatalClientErr(nil, "E_INVALID", fmt.Sprintf("invalid command %s", command))
}
// messagePump is the per-client goroutine that pushes messages, heartbeats
// and buffered output to the client.
//
// It closes startedChan once its local state has been captured from client,
// then loops until a write fails or client.ExitChan is closed. Inside the
// loop it reacts to: subscription and IDENTIFY events (each accepted once),
// ready-state changes, heartbeat ticks, output-buffer flush ticks and
// messages arriving on the subscribed channel's memory and backend queues.
func (p *protocolV2) messagePump(client *clientV2, startedChan chan bool) {
	var err error
	var buf bytes.Buffer
	var memoryMsgChan chan *Message
	var backendMsgChan chan []byte
	var subChannel *Channel
	// NOTE: `flusherChan` is used to bound message latency for
	// the pathological case of a channel on a low volume topic
	// with >1 clients having >1 RDY counts
	var flusherChan <-chan time.Time
	var sampleRate int32

	subEventChan := client.SubEventChan
	identifyEventChan := client.IdentifyEventChan
	outputBufferTicker := time.NewTicker(client.OutputBufferTimeout)
	heartbeatTicker := time.NewTicker(client.HeartbeatInterval)
	heartbeatChan := heartbeatTicker.C
	msgTimeout := client.MsgTimeout

	// v2 opportunistically buffers data to clients to reduce write system calls
	// we force flush in two cases:
	//    1. when the client is not ready to receive messages
	//    2. we're buffered and the channel has nothing left to send us
	//       (ie. we would block in this loop anyway)
	//
	flushed := true

	// signal to the goroutine that started the messagePump
	// that we've started up
	close(startedChan)

	for {
		if subChannel == nil || !client.IsReadyForMessages() {
			// the client is not ready to receive messages...
			// (nil channels are never selected, so no message is received)
			memoryMsgChan = nil
			backendMsgChan = nil
			flusherChan = nil
			// force flush
			client.writeLock.Lock()
			err = client.Flush()
			client.writeLock.Unlock()
			if err != nil {
				goto exit
			}
			flushed = true
		} else if flushed {
			// last iteration we flushed...
			// do not select on the flusher ticker channel
			memoryMsgChan = subChannel.memoryMsgChan
			backendMsgChan = subChannel.backend.ReadChan()
			flusherChan = nil
		} else {
			// we're buffered (if there isn't any more data we should flush)...
			// select on the flusher ticker channel, too
			memoryMsgChan = subChannel.memoryMsgChan
			backendMsgChan = subChannel.backend.ReadChan()
			flusherChan = outputBufferTicker.C
		}

		select {
		case <-flusherChan:
			// if this case wins, we're either starved
			// or we won the race between other channels...
			// in either case, force flush
			client.writeLock.Lock()
			err = client.Flush()
			client.writeLock.Unlock()
			if err != nil {
				goto exit
			}
			flushed = true
		case <-client.ReadyStateChan:
			// nothing to do here; the top of the loop re-evaluates readiness
		case subChannel = <-subEventChan:
			// you can't SUB anymore
			subEventChan = nil
		case identifyData := <-identifyEventChan:
			// you can't IDENTIFY anymore
			identifyEventChan = nil

			// a non-positive timeout leaves the (stopped) ticker in place,
			// so the flusher channel never fires again
			outputBufferTicker.Stop()
			if identifyData.OutputBufferTimeout > 0 {
				outputBufferTicker = time.NewTicker(identifyData.OutputBufferTimeout)
			}

			// a non-positive interval disables heartbeats entirely
			heartbeatTicker.Stop()
			heartbeatChan = nil
			if identifyData.HeartbeatInterval > 0 {
				heartbeatTicker = time.NewTicker(identifyData.HeartbeatInterval)
				heartbeatChan = heartbeatTicker.C
			}

			if identifyData.SampleRate > 0 {
				sampleRate = identifyData.SampleRate
			}

			msgTimeout = identifyData.MsgTimeout
		case <-heartbeatChan:
			err = p.Send(client, frameTypeResponse, heartbeatBytes)
			if err != nil {
				goto exit
			}
		case b := <-backendMsgChan:
			// with sampling enabled, drop the message when the random draw
			// exceeds the configured rate
			if sampleRate > 0 && rand.Int31n(100) > sampleRate {
				continue
			}

			// use a dedicated name for the decode error: redeclaring `err`
			// here would shadow the function-level variable and hide a
			// later SendMessage failure from the exit logging below
			msg, decodeErr := decodeMessage(b)
			if decodeErr != nil {
				p.ctx.nsqd.logf("ERROR: failed to decode message - %s", decodeErr)
				continue
			}
			msg.Attempts++

			subChannel.StartInFlightTimeout(msg, client.ID, msgTimeout)
			client.SendingMessage()
			err = p.SendMessage(client, msg, &buf)
			if err != nil {
				goto exit
			}
			flushed = false
		case msg := <-memoryMsgChan:
			if sampleRate > 0 && rand.Int31n(100) > sampleRate {
				continue
			}
			msg.Attempts++

			subChannel.StartInFlightTimeout(msg, client.ID, msgTimeout)
			client.SendingMessage()
			err = p.SendMessage(client, msg, &buf)
			if err != nil {
				goto exit
			}
			flushed = false
		case <-client.ExitChan:
			goto exit
		}
	}

exit:
	p.ctx.nsqd.logf("PROTOCOL(V2): [%s] exiting messagePump", client)
	heartbeatTicker.Stop()
	outputBufferTicker.Stop()
	if err != nil {
		p.ctx.nsqd.logf("PROTOCOL(V2): [%s] messagePump error - %s", client, err)
	}
}
// IDENTIFY handles the IDENTIFY command: it reads a length-prefixed JSON body
// describing the client, applies it via client.Identify and, when the client
// asked for feature negotiation, replies with the negotiated settings and
// upgrades the connection (TLS, snappy, deflate) as agreed.
//
// It is only valid in the initial state. Without feature negotiation it
// returns "OK". With negotiation it writes the JSON response (and an "OK"
// after each upgrade) itself and returns a nil response. Every failure is
// reported as a fatal client error.
func (p *protocolV2) IDENTIFY(client *clientV2, params [][]byte) ([]byte, error) {
	var err error

	if atomic.LoadInt32(&client.State) != stateInit {
		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot IDENTIFY in current state")
	}

	bodyLen, err := readLen(client.Reader, client.lenSlice)
	if err != nil {
		return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY", "IDENTIFY failed to read body size")
	}

	if int64(bodyLen) > p.ctx.nsqd.getOpts().MaxBodySize {
		return nil, protocol.NewFatalClientErr(nil, "E_BAD_BODY",
			fmt.Sprintf("IDENTIFY body too big %d > %d", bodyLen, p.ctx.nsqd.getOpts().MaxBodySize))
	}

	if bodyLen <= 0 {
		return nil, protocol.NewFatalClientErr(nil, "E_BAD_BODY",
			fmt.Sprintf("IDENTIFY invalid body size %d", bodyLen))
	}

	body := make([]byte, bodyLen)
	_, err = io.ReadFull(client.Reader, body)
	if err != nil {
		return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY", "IDENTIFY failed to read body")
	}

	// body is a json structure with producer information
	var identifyData identifyDataV2
	err = json.Unmarshal(body, &identifyData)
	if err != nil {
		return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY", "IDENTIFY failed to decode JSON body")
	}

	if p.ctx.nsqd.getOpts().Verbose {
		p.ctx.nsqd.logf("PROTOCOL(V2): [%s] %+v", client, identifyData)
	}

	err = client.Identify(identifyData)
	if err != nil {
		return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY", "IDENTIFY "+err.Error())
	}

	// bail out early if we're not negotiating features
	if !identifyData.FeatureNegotiation {
		return okBytes, nil
	}

	// a feature is only enabled when both the server and the client want it
	tlsv1 := p.ctx.nsqd.tlsConfig != nil && identifyData.TLSv1
	deflate := p.ctx.nsqd.getOpts().DeflateEnabled && identifyData.Deflate
	deflateLevel := 0
	if deflate {
		// default to level 6 when the client did not request a level,
		// otherwise honor the requested level; either way never exceed
		// the server-side maximum
		deflateLevel = 6
		if identifyData.DeflateLevel > 0 {
			deflateLevel = int(identifyData.DeflateLevel)
		}
		deflateLevel = int(math.Min(float64(deflateLevel), float64(p.ctx.nsqd.getOpts().MaxDeflateLevel)))
	}
	snappy := p.ctx.nsqd.getOpts().SnappyEnabled && identifyData.Snappy

	if deflate && snappy {
		return nil, protocol.NewFatalClientErr(nil, "E_IDENTIFY_FAILED", "cannot enable both deflate and snappy compression")
	}

	// durations are reported to the client in milliseconds
	resp, err := json.Marshal(struct {
		MaxRdyCount         int64  `json:"max_rdy_count"`
		Version             string `json:"version"`
		MaxMsgTimeout       int64  `json:"max_msg_timeout"`
		MsgTimeout          int64  `json:"msg_timeout"`
		TLSv1               bool   `json:"tls_v1"`
		Deflate             bool   `json:"deflate"`
		DeflateLevel        int    `json:"deflate_level"`
		MaxDeflateLevel     int    `json:"max_deflate_level"`
		Snappy              bool   `json:"snappy"`
		SampleRate          int32  `json:"sample_rate"`
		AuthRequired        bool   `json:"auth_required"`
		OutputBufferSize    int    `json:"output_buffer_size"`
		OutputBufferTimeout int64  `json:"output_buffer_timeout"`
	}{
		MaxRdyCount:         p.ctx.nsqd.getOpts().MaxRdyCount,
		Version:             version.Binary,
		MaxMsgTimeout:       int64(p.ctx.nsqd.getOpts().MaxMsgTimeout / time.Millisecond),
		MsgTimeout:          int64(client.MsgTimeout / time.Millisecond),
		TLSv1:               tlsv1,
		Deflate:             deflate,
		DeflateLevel:        deflateLevel,
		MaxDeflateLevel:     p.ctx.nsqd.getOpts().MaxDeflateLevel,
		Snappy:              snappy,
		SampleRate:          client.SampleRate,
		AuthRequired:        p.ctx.nsqd.IsAuthEnabled(),
		OutputBufferSize:    client.OutputBufferSize,
		OutputBufferTimeout: int64(client.OutputBufferTimeout / time.Millisecond),
	})
	if err != nil {
		return nil, protocol.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error())
	}

	err = p.Send(client, frameTypeResponse, resp)
	if err != nil {
		return nil, protocol.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error())
	}

	// each upgrade is followed by an "OK" sent over the upgraded connection
	if tlsv1 {
		p.ctx.nsqd.logf("PROTOCOL(V2): [%s] upgrading connection to TLS", client)
		err = client.UpgradeTLS()
		if err != nil {
			return nil, protocol.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error())
		}

		err = p.Send(client, frameTypeResponse, okBytes)
		if err != nil {
			return nil, protocol.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error())
		}
	}

	if snappy {
		p.ctx.nsqd.logf("PROTOCOL(V2): [%s] upgrading connection to snappy", client)
		err = client.UpgradeSnappy()
		if err != nil {
			return nil, protocol.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error())
		}

		err = p.Send(client, frameTypeResponse, okBytes)
		if err != nil {
			return nil, protocol.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error())
		}
	}

	if deflate {
		p.ctx.nsqd.logf("PROTOCOL(V2): [%s] upgrading connection to deflate", client)
		err = client.UpgradeDeflate(deflateLevel)
		if err != nil {
			return nil, protocol.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error())
		}

		err = p.Send(client, frameTypeResponse, okBytes)
		if err != nil {
			return nil, protocol.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error())
		}
	}

	return nil, nil
}
// AUTH handles the AUTH command: it reads a length-prefixed secret from the
// client, passes it to client.Auth and replies with the resulting identity
// and permission count as JSON.
//
// The command takes no inline parameters and is only valid in the initial
// state, once per connection, and only when auth is enabled on the server.
// The JSON reply is written here, so a nil response is returned on success.
// Every failure is a fatal client error.
func (p *protocolV2) AUTH(client *clientV2, params [][]byte) ([]byte, error) {
	if atomic.LoadInt32(&client.State) != stateInit {
		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot AUTH in current state")
	}
	if len(params) != 1 {
		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "AUTH invalid number of parameters")
	}

	bodyLen, err := readLen(client.Reader, client.lenSlice)
	if err != nil {
		return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY", "AUTH failed to read body size")
	}

	maxBodySize := p.ctx.nsqd.getOpts().MaxBodySize
	if int64(bodyLen) > maxBodySize {
		return nil, protocol.NewFatalClientErr(nil, "E_BAD_BODY",
			fmt.Sprintf("AUTH body too big %d > %d", bodyLen, maxBodySize))
	}
	if bodyLen <= 0 {
		return nil, protocol.NewFatalClientErr(nil, "E_BAD_BODY",
			fmt.Sprintf("AUTH invalid body size %d", bodyLen))
	}

	secret := make([]byte, bodyLen)
	if _, err = io.ReadFull(client.Reader, secret); err != nil {
		return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY", "AUTH failed to read body")
	}

	if client.HasAuthorizations() {
		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "AUTH Already set")
	}
	if !client.ctx.nsqd.IsAuthEnabled() {
		// err is known to be nil at this point; it is passed through as-is
		return nil, protocol.NewFatalClientErr(err, "E_AUTH_DISABLED", "AUTH Disabled")
	}

	if authErr := client.Auth(string(secret)); authErr != nil {
		// we don't want to leak errors contacting the auth server to untrusted clients
		p.ctx.nsqd.logf("PROTOCOL(V2): [%s] Auth Failed %s", client, authErr)
		return nil, protocol.NewFatalClientErr(authErr, "E_AUTH_FAILED", "AUTH failed")
	}
	if !client.HasAuthorizations() {
		return nil, protocol.NewFatalClientErr(nil, "E_UNAUTHORIZED", "AUTH No authorizations found")
	}

	// shape of the JSON document sent back to the client
	type authResponse struct {
		Identity        string `json:"identity"`
		IdentityURL     string `json:"identity_url"`
		PermissionCount int    `json:"permission_count"`
	}
	resp, err := json.Marshal(authResponse{
		Identity:        client.AuthState.Identity,
		IdentityURL:     client.AuthState.IdentityURL,
		PermissionCount: len(client.AuthState.Authorizations),
	})
	if err != nil {
		return nil, protocol.NewFatalClientErr(err, "E_AUTH_ERROR", "AUTH error "+err.Error())
	}

	if err = p.Send(client, frameTypeResponse, resp); err != nil {
		return nil, protocol.NewFatalClientErr(err, "E_AUTH_ERROR", "AUTH error "+err.Error())
	}

	return nil, nil
}
// CheckAuth verifies that client may run cmd against topicName/channelName.
//
// With auth disabled it always succeeds. Otherwise the client must already
// hold authorizations, and client.IsAuthorized must approve the topic and
// channel. Any failure is returned as a fatal client error; the underlying
// lookup error is logged but not forwarded to the client.
func (p *protocolV2) CheckAuth(client *clientV2, cmd, topicName, channelName string) error {
	if !client.ctx.nsqd.IsAuthEnabled() {
		return nil
	}

	// auth is enabled, so the client must have authorized already
	if !client.HasAuthorizations() {
		return protocol.NewFatalClientErr(nil, "E_AUTH_FIRST",
			fmt.Sprintf("AUTH required before %s", cmd))
	}

	// compare topic/channel against the client's authorization data
	authorized, err := client.IsAuthorized(topicName, channelName)
	if err != nil {
		// we don't want to leak errors contacting the auth server to untrusted clients
		p.ctx.nsqd.logf("PROTOCOL(V2): [%s] Auth Failed %s", client, err)
		return protocol.NewFatalClientErr(nil, "E_AUTH_FAILED", "AUTH failed")
	}
	if !authorized {
		return protocol.NewFatalClientErr(nil, "E_UNAUTHORIZED",
			fmt.Sprintf("AUTH failed for %s on %q %q", cmd, topicName, channelName))
	}

	return nil
}
// SUB handles "SUB <topic> <channel>": it subscribes the client to the named
// channel and hands that channel to the messagePump.
//
// The command is only valid in the initial state and with heartbeats
// enabled. Both names are validated and the client must pass CheckAuth.
// Returns "OK" on success; every failure is a fatal client error.
func (p *protocolV2) SUB(client *clientV2, params [][]byte) ([]byte, error) {
	if atomic.LoadInt32(&client.State) != stateInit {
		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot SUB in current state")
	}
	if client.HeartbeatInterval <= 0 {
		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot SUB with heartbeats disabled")
	}
	if len(params) < 3 {
		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "SUB insufficient number of parameters")
	}

	topicName, channelName := string(params[1]), string(params[2])
	if !protocol.IsValidTopicName(topicName) {
		return nil, protocol.NewFatalClientErr(nil, "E_BAD_TOPIC",
			fmt.Sprintf("SUB topic name %q is not valid", topicName))
	}
	if !protocol.IsValidChannelName(channelName) {
		return nil, protocol.NewFatalClientErr(nil, "E_BAD_CHANNEL",
			fmt.Sprintf("SUB channel name %q is not valid", channelName))
	}

	if authErr := p.CheckAuth(client, "SUB", topicName, channelName); authErr != nil {
		return nil, authErr
	}

	subscribed := p.ctx.nsqd.GetTopic(topicName).GetChannel(channelName)
	subscribed.AddClient(client.ID, client)

	atomic.StoreInt32(&client.State, stateSubscribed)
	client.Channel = subscribed
	// update message pump
	client.SubEventChan <- subscribed

	return okBytes, nil
}
// RDY handles "RDY [count]": it sets how many messages the client is ready
// to receive. The count defaults to 1 when omitted.
//
// In the closing state the command is logged and ignored. In any state other
// than subscribed, or when the count cannot be parsed or falls outside
// 0..MaxRdyCount, a fatal client error is returned. No response is sent on
// success.
func (p *protocolV2) RDY(client *clientV2, params [][]byte) ([]byte, error) {
	switch state := atomic.LoadInt32(&client.State); {
	case state == stateClosing:
		// just ignore ready changes on a closing channel
		p.ctx.nsqd.logf(
			"PROTOCOL(V2): [%s] ignoring RDY after CLS in state ClientStateV2Closing",
			client)
		return nil, nil
	case state != stateSubscribed:
		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot RDY in current state")
	}

	readyCount := int64(1)
	if len(params) > 1 {
		parsed, err := protocol.ByteToBase10(params[1])
		if err != nil {
			return nil, protocol.NewFatalClientErr(err, "E_INVALID",
				fmt.Sprintf("RDY could not parse count %s", params[1]))
		}
		readyCount = int64(parsed)
	}

	maxRdyCount := p.ctx.nsqd.getOpts().MaxRdyCount
	if readyCount < 0 || readyCount > maxRdyCount {
		// this needs to be a fatal error otherwise clients would have
		// inconsistent state
		return nil, protocol.NewFatalClientErr(nil, "E_INVALID",
			fmt.Sprintf("RDY count %d out of range 0-%d", readyCount, maxRdyCount))
	}

	client.SetReadyCount(readyCount)

	return nil, nil
}
// FIN handles "FIN <message id>": it marks an in-flight message as finished
// on the client's channel.
//
// Valid only in the subscribed or closing state. A malformed request is a
// fatal client error; a failure to finish the message is a non-fatal
// E_FIN_FAILED error. No response is sent on success.
func (p *protocolV2) FIN(client *clientV2, params [][]byte) ([]byte, error) {
	state := atomic.LoadInt32(&client.State)
	if !(state == stateSubscribed || state == stateClosing) {
		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot FIN in current state")
	}

	if len(params) < 2 {
		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "FIN insufficient number of params")
	}

	msgID, err := getMessageID(params[1])
	if err != nil {
		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", err.Error())
	}

	if err = client.Channel.FinishMessage(client.ID, *msgID); err != nil {
		return nil, protocol.NewClientErr(err, "E_FIN_FAILED",
			fmt.Sprintf("FIN %s failed %s", *msgID, err.Error()))
	}

	client.FinishedMessage()

	return nil, nil
}
// REQ handles "REQ <message id> <timeout ms>": it requeues an in-flight
// message on the client's channel after the given delay.
//
// Valid only in the subscribed or closing state. The timeout is given in
// milliseconds and must lie within 0..MaxReqTimeout. A malformed request is a
// fatal client error; a failure to requeue is a non-fatal E_REQ_FAILED error.
// No response is sent on success.
func (p *protocolV2) REQ(client *clientV2, params [][]byte) ([]byte, error) {
	state := atomic.LoadInt32(&client.State)
	if state != stateSubscribed && state != stateClosing {
		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot REQ in current state")
	}

	if len(params) < 3 {
		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "REQ insufficient number of params")
	}

	id, err := getMessageID(params[1])
	if err != nil {
		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", err.Error())
	}

	timeoutMs, err := protocol.ByteToBase10(params[2])
	if err != nil {
		return nil, protocol.NewFatalClientErr(err, "E_INVALID",
			fmt.Sprintf("REQ could not parse timeout %s", params[2]))
	}
	timeoutDuration := time.Duration(timeoutMs) * time.Millisecond

	if timeoutDuration < 0 || timeoutDuration > p.ctx.nsqd.getOpts().MaxReqTimeout {
		// report both values in milliseconds, the unit the client used,
		// rather than formatting the Durations as raw nanosecond counts
		return nil, protocol.NewFatalClientErr(nil, "E_INVALID",
			fmt.Sprintf("REQ timeout %d out of range 0-%d",
				timeoutMs, p.ctx.nsqd.getOpts().MaxReqTimeout/time.Millisecond))
	}

	err = client.Channel.RequeueMessage(client.ID, *id, timeoutDuration)
	if err != nil {
		return nil, protocol.NewClientErr(err, "E_REQ_FAILED",
			fmt.Sprintf("REQ %s failed %s", *id, err.Error()))
	}

	client.RequeuedMessage()

	return nil, nil
}
// CLS handles the CLS command: it starts closing the client's subscription
// and replies with "CLOSE_WAIT". It is a fatal client error unless the client
// is currently subscribed.
func (p *protocolV2) CLS(client *clientV2, params [][]byte) ([]byte, error) {
	if state := atomic.LoadInt32(&client.State); state != stateSubscribed {
		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot CLS in current state")
	}

	client.StartClose()

	closeWait := []byte("CLOSE_WAIT")
	return closeWait, nil
}
// NOP handles the NOP command: it does nothing and sends no response.
func (p *protocolV2) NOP(client *clientV2, params [][]byte) ([]byte, error) {
return nil, nil
}
// PUB handles "PUB <topic>" followed by a length-prefixed body: it publishes
// a single message to the named topic.
//
// The topic name must be valid, the body size must be positive and no larger
// than MaxMsgSize, and the client must pass CheckAuth. Returns "OK" on
// success; every failure is a fatal client error.
func (p *protocolV2) PUB(client *clientV2, params [][]byte) ([]byte, error) {
	if len(params) < 2 {
		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "PUB insufficient number of parameters")
	}

	topicName := string(params[1])
	if !protocol.IsValidTopicName(topicName) {
		return nil, protocol.NewFatalClientErr(nil, "E_BAD_TOPIC",
			fmt.Sprintf("PUB topic name %q is not valid", topicName))
	}

	bodyLen, err := readLen(client.Reader, client.lenSlice)
	if err != nil {
		return nil, protocol.NewFatalClientErr(err, "E_BAD_MESSAGE", "PUB failed to read message body size")
	}
	if bodyLen <= 0 {
		return nil, protocol.NewFatalClientErr(nil, "E_BAD_MESSAGE",
			fmt.Sprintf("PUB invalid message body size %d", bodyLen))
	}
	if maxMsgSize := p.ctx.nsqd.getOpts().MaxMsgSize; int64(bodyLen) > maxMsgSize {
		return nil, protocol.NewFatalClientErr(nil, "E_BAD_MESSAGE",
			fmt.Sprintf("PUB message too big %d > %d", bodyLen, maxMsgSize))
	}

	payload := make([]byte, bodyLen)
	if _, err = io.ReadFull(client.Reader, payload); err != nil {
		return nil, protocol.NewFatalClientErr(err, "E_BAD_MESSAGE", "PUB failed to read message body")
	}

	// the body is consumed from the connection before the auth check runs
	if err = p.CheckAuth(client, "PUB", topicName, ""); err != nil {
		return nil, err
	}

	topic := p.ctx.nsqd.GetTopic(topicName)
	msg := NewMessage(<-p.ctx.nsqd.idChan, payload)
	if err = topic.PutMessage(msg); err != nil {
		return nil, protocol.NewFatalClientErr(err, "E_PUB_FAILED", "PUB failed "+err.Error())
	}

	return okBytes, nil
}
// MPUB handles "MPUB <topic>" followed by a length-prefixed body containing
// several messages: it publishes all of them to the named topic.
//
// The topic name must be valid, the overall body size must be positive and no
// larger than MaxBodySize, and the client must pass CheckAuth. The individual
// messages are parsed and size-checked by readMPUB. Returns "OK" on success;
// every failure is a fatal client error.
func (p *protocolV2) MPUB(client *clientV2, params [][]byte) ([]byte, error) {
	var err error

	if len(params) < 2 {
		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "MPUB insufficient number of parameters")
	}

	topicName := string(params[1])
	if !protocol.IsValidTopicName(topicName) {
		// the error code is passed separately, so the description must not
		// repeat it (this matches the PUB, DPUB and SUB handlers)
		return nil, protocol.NewFatalClientErr(nil, "E_BAD_TOPIC",
			fmt.Sprintf("MPUB topic name %q is not valid", topicName))
	}

	// bodyLen is only validated here; readMPUB reads the message count and
	// each message from the connection on its own
	bodyLen, err := readLen(client.Reader, client.lenSlice)
	if err != nil {
		return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY", "MPUB failed to read body size")
	}

	if bodyLen <= 0 {
		return nil, protocol.NewFatalClientErr(nil, "E_BAD_BODY",
			fmt.Sprintf("MPUB invalid body size %d", bodyLen))
	}

	if int64(bodyLen) > p.ctx.nsqd.getOpts().MaxBodySize {
		return nil, protocol.NewFatalClientErr(nil, "E_BAD_BODY",
			fmt.Sprintf("MPUB body too big %d > %d", bodyLen, p.ctx.nsqd.getOpts().MaxBodySize))
	}

	messages, err := readMPUB(client.Reader, client.lenSlice, p.ctx.nsqd.idChan,
		p.ctx.nsqd.getOpts().MaxMsgSize)
	if err != nil {
		return nil, err
	}

	if err := p.CheckAuth(client, "MPUB", topicName, ""); err != nil {
		return nil, err
	}

	topic := p.ctx.nsqd.GetTopic(topicName)

	// if we've made it this far we've validated all the input,
	// the only possible error is that the topic is exiting during
	// this next call (and no messages will be queued in that case)
	err = topic.PutMessages(messages)
	if err != nil {
		return nil, protocol.NewFatalClientErr(err, "E_MPUB_FAILED", "MPUB failed "+err.Error())
	}

	return okBytes, nil
}
// DPUB handles "DPUB <topic> <defer ms>" followed by a length-prefixed body:
// it publishes a single message whose deferred field is set to the requested
// delay.
//
// The topic name must be valid, the delay (milliseconds) must lie within
// 0..MaxReqTimeout, the body size must be positive and no larger than
// MaxMsgSize, and the client must pass CheckAuth. Returns "OK" on success;
// every failure is a fatal client error.
func (p *protocolV2) DPUB(client *clientV2, params [][]byte) ([]byte, error) {
	if len(params) < 3 {
		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "DPUB insufficient number of parameters")
	}

	topicName := string(params[1])
	if !protocol.IsValidTopicName(topicName) {
		return nil, protocol.NewFatalClientErr(nil, "E_BAD_TOPIC",
			fmt.Sprintf("DPUB topic name %q is not valid", topicName))
	}

	timeoutMs, err := protocol.ByteToBase10(params[2])
	if err != nil {
		return nil, protocol.NewFatalClientErr(err, "E_INVALID",
			fmt.Sprintf("DPUB could not parse timeout %s", params[2]))
	}

	deferFor := time.Duration(timeoutMs) * time.Millisecond
	if deferFor < 0 || deferFor > p.ctx.nsqd.getOpts().MaxReqTimeout {
		return nil, protocol.NewFatalClientErr(nil, "E_INVALID",
			fmt.Sprintf("DPUB timeout %d out of range 0-%d",
				timeoutMs, p.ctx.nsqd.getOpts().MaxReqTimeout/time.Millisecond))
	}

	bodyLen, err := readLen(client.Reader, client.lenSlice)
	if err != nil {
		return nil, protocol.NewFatalClientErr(err, "E_BAD_MESSAGE", "DPUB failed to read message body size")
	}
	if bodyLen <= 0 {
		return nil, protocol.NewFatalClientErr(nil, "E_BAD_MESSAGE",
			fmt.Sprintf("DPUB invalid message body size %d", bodyLen))
	}
	if maxMsgSize := p.ctx.nsqd.getOpts().MaxMsgSize; int64(bodyLen) > maxMsgSize {
		return nil, protocol.NewFatalClientErr(nil, "E_BAD_MESSAGE",
			fmt.Sprintf("DPUB message too big %d > %d", bodyLen, maxMsgSize))
	}

	payload := make([]byte, bodyLen)
	if _, err = io.ReadFull(client.Reader, payload); err != nil {
		return nil, protocol.NewFatalClientErr(err, "E_BAD_MESSAGE", "DPUB failed to read message body")
	}

	// the body is consumed from the connection before the auth check runs
	if err = p.CheckAuth(client, "DPUB", topicName, ""); err != nil {
		return nil, err
	}

	topic := p.ctx.nsqd.GetTopic(topicName)
	msg := NewMessage(<-p.ctx.nsqd.idChan, payload)
	msg.deferred = deferFor
	if err = topic.PutMessage(msg); err != nil {
		return nil, protocol.NewFatalClientErr(err, "E_DPUB_FAILED", "DPUB failed "+err.Error())
	}

	return okBytes, nil
}
// TOUCH handles "TOUCH <message id>": it asks the client's channel to touch
// an in-flight message, passing the client's current MsgTimeout.
//
// Valid only in the subscribed or closing state. A malformed request is a
// fatal client error; a failure to touch the message is a non-fatal
// E_TOUCH_FAILED error. No response is sent on success.
func (p *protocolV2) TOUCH(client *clientV2, params [][]byte) ([]byte, error) {
	state := atomic.LoadInt32(&client.State)
	if !(state == stateSubscribed || state == stateClosing) {
		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot TOUCH in current state")
	}

	if len(params) < 2 {
		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "TOUCH insufficient number of params")
	}

	msgID, err := getMessageID(params[1])
	if err != nil {
		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", err.Error())
	}

	// MsgTimeout is read while holding the read side of writeLock
	client.writeLock.RLock()
	timeout := client.MsgTimeout
	client.writeLock.RUnlock()

	if err = client.Channel.TouchMessage(client.ID, *msgID, timeout); err != nil {
		return nil, protocol.NewClientErr(err, "E_TOUCH_FAILED",
			fmt.Sprintf("TOUCH %s failed %s", *msgID, err.Error()))
	}

	return nil, nil
}
// readMPUB parses the payload of an MPUB command from r: a message count
// followed by that many length-prefixed message bodies.
//
// tmp is the scratch buffer handed to readLen for each length prefix, idChan
// supplies an ID for every message and maxMessageSize bounds each body. The
// count and every body size must be positive. Every failure is returned as a
// fatal client error and no messages are returned in that case.
func readMPUB(r io.Reader, tmp []byte, idChan chan MessageID, maxMessageSize int64) ([]*Message, error) {
	// upper bound on the capacity reserved before any body has been read
	const maxPrealloc = 1024

	numMessages, err := readLen(r, tmp)
	if err != nil {
		return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY", "MPUB failed to read message count")
	}

	if numMessages <= 0 {
		return nil, protocol.NewFatalClientErr(nil, "E_BAD_BODY",
			fmt.Sprintf("MPUB invalid message count %d", numMessages))
	}

	// numMessages comes straight off the wire and has not been checked
	// against the data that actually follows, so do not let it dictate the
	// size of the up-front allocation; append grows the slice as needed
	capacity := int(numMessages)
	if capacity > maxPrealloc {
		capacity = maxPrealloc
	}
	messages := make([]*Message, 0, capacity)

	for i := int32(0); i < numMessages; i++ {
		messageSize, err := readLen(r, tmp)
		if err != nil {
			return nil, protocol.NewFatalClientErr(err, "E_BAD_MESSAGE",
				fmt.Sprintf("MPUB failed to read message(%d) body size", i))
		}

		if messageSize <= 0 {
			return nil, protocol.NewFatalClientErr(nil, "E_BAD_MESSAGE",
				fmt.Sprintf("MPUB invalid message(%d) body size %d", i, messageSize))
		}

		if int64(messageSize) > maxMessageSize {
			return nil, protocol.NewFatalClientErr(nil, "E_BAD_MESSAGE",
				fmt.Sprintf("MPUB message too big %d > %d", messageSize, maxMessageSize))
		}

		msgBody := make([]byte, messageSize)
		_, err = io.ReadFull(r, msgBody)
		if err != nil {
			return nil, protocol.NewFatalClientErr(err, "E_BAD_MESSAGE", "MPUB failed to read message body")
		}

		messages = append(messages, NewMessage(<-idChan, msgBody))
	}

	return messages, nil
}
// getMessageID checks that the bytes read off the wire have exactly the
// length of a message ID and reinterprets them in place as a MessageID.
//
// No copy is made: the returned pointer aliases p's backing array, so it is
// only meaningful while that memory is left untouched.
func getMessageID(p []byte) (*MessageID, error) {
	if len(p) == MsgIDLength {
		id := (*MessageID)(unsafe.Pointer(&p[0]))
		return id, nil
	}
	return nil, errors.New("Invalid Message ID")
}
// readLen fills tmp from r and decodes its first four bytes as a big-endian
// 32-bit value, returned as an int32 (so large unsigned values come back
// negative). If tmp cannot be filled, it returns 0 and the read error.
func readLen(r io.Reader, tmp []byte) (int32, error) {
	if _, err := io.ReadFull(r, tmp); err != nil {
		return 0, err
	}
	size := binary.BigEndian.Uint32(tmp)
	return int32(size), nil
}
// enforceTLSPolicy rejects command with a fatal E_INVALID error when the
// server's TLSRequired option is anything other than TLSNotRequired and the
// client's TLS flag is not set to 1. Otherwise it returns nil.
func enforceTLSPolicy(client *clientV2, p *protocolV2, command []byte) error {
	tlsRequired := p.ctx.nsqd.getOpts().TLSRequired != TLSNotRequired
	if !tlsRequired || atomic.LoadInt32(&client.TLS) == 1 {
		return nil
	}
	return protocol.NewFatalClientErr(nil, "E_INVALID",
		fmt.Sprintf("cannot %s in current state (TLS required)", command))
}