486 lines
10 KiB
Go
486 lines
10 KiB
Go
|
package wireproto
|
||
|
|
||
|
import (
|
||
|
`bytes`
|
||
|
`cmp`
|
||
|
`fmt`
|
||
|
`hash/crc32`
|
||
|
`io`
|
||
|
`strings`
|
||
|
|
||
|
`r00t2.io/goutils/multierr`
|
||
|
)
|
||
|
|
||
|
// GenChecksum (re-)generates and returns the checksum. The body that is checksummed is returned in buf.
|
||
|
func (r *Response) GenChecksum() (cksum uint32, buf *bytes.Buffer, err error) {
|
||
|
|
||
|
var b []byte
|
||
|
var size int
|
||
|
|
||
|
buf = new(bytes.Buffer)
|
||
|
|
||
|
_ = r.Size()
|
||
|
|
||
|
for _, p := range r.RecordGroups {
|
||
|
size += p.Size()
|
||
|
}
|
||
|
|
||
|
if _, err = buf.Write(hdrBODYSTART); err != nil {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
if _, err = buf.Write(PackInt(len(r.RecordGroups))); err != nil {
|
||
|
return
|
||
|
}
|
||
|
if _, err = buf.Write(PackInt(size)); err != nil {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
for _, rg := range r.RecordGroups {
|
||
|
if b, err = rg.MarshalBinary(); err != nil {
|
||
|
return
|
||
|
}
|
||
|
if _, err = buf.Write(b); err != nil {
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if _, err = buf.Write(hdrBODYEND); err != nil {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
cksum = crc32.ChecksumIEEE(buf.Bytes())
|
||
|
|
||
|
r.Checksum = cksum
|
||
|
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// RespStat trawls the KVP for a field name with "error" and updates the status to reflect an error was found.
|
||
|
func (r *Response) RespStat() {
|
||
|
|
||
|
r.Status = RespStatusByteOK
|
||
|
|
||
|
for _, rg := range r.RecordGroups {
|
||
|
for _, rec := range rg.Records {
|
||
|
for _, kvp := range rec.Pairs {
|
||
|
if strings.ToLower(kvp.Name.String()) == "error" {
|
||
|
r.Status = RespStatusByteErr
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
}
|
||
|
|
||
|
// MarshalBinary renders a Response into a byte-packed format.
|
||
|
func (r *Response) MarshalBinary() (data []byte, err error) {
|
||
|
|
||
|
var b []byte
|
||
|
var msgSize int
|
||
|
var hasErr bool
|
||
|
var mErr *multierr.MultiError = multierr.NewMultiError(nil)
|
||
|
var buf *bytes.Buffer = new(bytes.Buffer)
|
||
|
var msgBuf *bytes.Buffer = new(bytes.Buffer)
|
||
|
|
||
|
_ = r.Size()
|
||
|
|
||
|
for _, i := range r.RecordGroups {
|
||
|
msgSize += i.Size()
|
||
|
}
|
||
|
|
||
|
// The message "body" - we do this first so we can checksum .
|
||
|
if _, err = msgBuf.Write(hdrBODYSTART); err != nil {
|
||
|
return
|
||
|
}
|
||
|
// Record group count
|
||
|
if _, err = msgBuf.Write(PackInt(len(r.RecordGroups))); err != nil {
|
||
|
return
|
||
|
}
|
||
|
// And size.
|
||
|
if _, err = msgBuf.Write(PackInt(msgSize)); err != nil {
|
||
|
return
|
||
|
}
|
||
|
for _, i := range r.RecordGroups {
|
||
|
if b, err = i.MarshalBinary(); err != nil {
|
||
|
mErr.AddError(err)
|
||
|
err = nil
|
||
|
hasErr = true
|
||
|
}
|
||
|
if _, err = msgBuf.Write(b); err != nil {
|
||
|
mErr.AddError(err)
|
||
|
err = nil
|
||
|
hasErr = true
|
||
|
}
|
||
|
}
|
||
|
if _, err = msgBuf.Write(hdrBODYEND); err != nil {
|
||
|
mErr.AddError(err)
|
||
|
err = nil
|
||
|
hasErr = true
|
||
|
}
|
||
|
|
||
|
// Now we write the response as a whole.
|
||
|
|
||
|
// Status
|
||
|
if r.Status == RespStatusByteOK && hasErr {
|
||
|
r.Status = RespStatusByteErr
|
||
|
}
|
||
|
if _, err = buf.Write([]byte{r.Status}); err != nil {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// Checksum -- ALWAYS present for responses!
|
||
|
if _, _, err = r.GenChecksum(); err != nil {
|
||
|
return
|
||
|
}
|
||
|
if _, err = buf.Write(hdrCKSUM); err != nil {
|
||
|
return
|
||
|
}
|
||
|
if _, err = buf.Write(cksumBytes(r.Checksum)); err != nil {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// Message start
|
||
|
if _, err = buf.Write(hdrMSGSTART); err != nil {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// Protocol version
|
||
|
if _, err = buf.Write(PackUint32(r.ProtocolVersion)); err != nil {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// Then copy the msgBuf in.
|
||
|
if _, err = msgBuf.WriteTo(buf); err != nil {
|
||
|
return
|
||
|
}
|
||
|
// And then the message end.
|
||
|
if _, err = buf.Write(hdrMSGEND); err != nil {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
data = buf.Bytes()
|
||
|
|
||
|
if !mErr.IsEmpty() {
|
||
|
err = mErr
|
||
|
return
|
||
|
}
|
||
|
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// Model returns an indented string representation of the model.
|
||
|
func (r *Response) Model() (out string) {
|
||
|
|
||
|
out = r.ModelCustom(IndentChars, SeparatorChars, indentR)
|
||
|
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// ModelCustom is like Model with user-defined formatting.
|
||
|
func (r *Response) ModelCustom(indent, sep string, level uint) (out string) {
|
||
|
|
||
|
var maxFtr int
|
||
|
var size int
|
||
|
var sb strings.Builder
|
||
|
|
||
|
for _, rg := range r.RecordGroups {
|
||
|
size += rg.Size()
|
||
|
}
|
||
|
|
||
|
_, _, _ = r.GenChecksum()
|
||
|
|
||
|
// HDR: RESPSTART/RESPERR (RESPSTATUS)
|
||
|
sb.WriteString(strings.Repeat(indent, int(level)))
|
||
|
sb.WriteString(padBytesRight([]byte{byte(r.Status)}, 8))
|
||
|
sb.WriteString(sep)
|
||
|
switch r.Status {
|
||
|
case RespStatusByteOK:
|
||
|
sb.WriteString("// HDR:RESPSTART (Status: OK)\n")
|
||
|
case RespStatusByteErr:
|
||
|
sb.WriteString("// HDR:RESPERR (Status: Error)\n")
|
||
|
}
|
||
|
|
||
|
// HDR: CKSUM
|
||
|
sb.WriteString(strings.Repeat(indent, int(level)))
|
||
|
sb.WriteString(padBytesRight(hdrCKSUM, 8))
|
||
|
sb.WriteString(sep)
|
||
|
sb.WriteString("// HDR:CKSUM\n")
|
||
|
// Checksum
|
||
|
sb.WriteString(strings.Repeat(indent, int(level)))
|
||
|
sb.WriteString(padBytesRight(cksumBytes(r.Checksum), 8))
|
||
|
sb.WriteString(sep)
|
||
|
sb.WriteString(fmt.Sprintf("// Checksum Value (%d)\n", r.Checksum))
|
||
|
|
||
|
// Header: MSGSTART
|
||
|
sb.WriteString(strings.Repeat(indent, int(level)))
|
||
|
sb.WriteString(padBytesRight(hdrMSGSTART, 8))
|
||
|
sb.WriteString(sep)
|
||
|
sb.WriteString("// HDR:MSGSTART\n")
|
||
|
|
||
|
// Protocol Version
|
||
|
sb.WriteString(strings.Repeat(indent, int(level)))
|
||
|
sb.WriteString(padIntRight(int(r.ProtocolVersion), 8))
|
||
|
sb.WriteString(sep)
|
||
|
sb.WriteString(fmt.Sprintf("// Protocol Version (%d)\n", r.ProtocolVersion))
|
||
|
|
||
|
// Header: BODYSTART
|
||
|
sb.WriteString(strings.Repeat(indent, int(level)))
|
||
|
sb.WriteString(padBytesRight(hdrBODYSTART, 8))
|
||
|
sb.WriteString(sep)
|
||
|
sb.WriteString("// HDR:BODYSTART\n")
|
||
|
|
||
|
// Count
|
||
|
sb.WriteString(strings.Repeat(indent, int(level)))
|
||
|
sb.WriteString(padIntRight(len(r.RecordGroups), 8))
|
||
|
sb.WriteString(sep)
|
||
|
sb.WriteString(fmt.Sprintf("// Record Group Count (%d)\n", len(r.RecordGroups)))
|
||
|
// Size
|
||
|
sb.WriteString(strings.Repeat(indent, int(level)))
|
||
|
sb.WriteString(padIntRight(size, 8))
|
||
|
sb.WriteString(sep)
|
||
|
sb.WriteString(fmt.Sprintf("// Record Groups Size (%d)\n", size))
|
||
|
|
||
|
// VALUES
|
||
|
for idx, rg := range r.RecordGroups {
|
||
|
sb.WriteString(fmt.Sprintf("// Record Group %d (%d)\n", idx+1, rg.Size()))
|
||
|
sb.WriteString(rg.ModelCustom(indent, sep, level+1))
|
||
|
}
|
||
|
|
||
|
// Make the footers a little more nicely aligned.
|
||
|
switch cmp.Compare(len(hdrBODYEND), len(hdrMSGEND)) {
|
||
|
case -1:
|
||
|
maxFtr = len(hdrMSGEND)
|
||
|
case 1, 0:
|
||
|
maxFtr = len(hdrBODYEND)
|
||
|
}
|
||
|
|
||
|
// Footer: BODYEND
|
||
|
sb.WriteString(strings.Repeat(indent, int(level)))
|
||
|
sb.WriteString(padBytesRight(hdrBODYEND, maxFtr))
|
||
|
sb.WriteString(sep)
|
||
|
sb.WriteString("// HDR:BODYEND\n")
|
||
|
|
||
|
// Footer: MSGEND
|
||
|
sb.WriteString(strings.Repeat(indent, int(level)))
|
||
|
sb.WriteString(padBytesRight(hdrMSGEND, maxFtr))
|
||
|
sb.WriteString(sep)
|
||
|
sb.WriteString("// HDR:MSGEND\n")
|
||
|
|
||
|
out = sb.String()
|
||
|
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// Resolve associates children with parents.
|
||
|
func (r *Response) Resolve() {
|
||
|
for idx, i := range r.RecordGroups {
|
||
|
i.parent = r
|
||
|
i.rgIdx = idx
|
||
|
i.Resolve()
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Size returns the Response's calculated size (in bytes) and updates the size field if 0.
|
||
|
func (r *Response) Size() (size int) {
|
||
|
|
||
|
if r == nil {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// Response Status
|
||
|
size += 1
|
||
|
|
||
|
// Checksum
|
||
|
size += len(hdrCKSUM)
|
||
|
size += CksumPackedSize
|
||
|
|
||
|
// Message header
|
||
|
size += len(hdrMSGSTART)
|
||
|
|
||
|
// Protocol version
|
||
|
size += PackedNumSize
|
||
|
|
||
|
// Count and Size uint32s
|
||
|
size += PackedNumSize * 2
|
||
|
|
||
|
// Message begin
|
||
|
size += len(hdrBODYSTART)
|
||
|
|
||
|
for _, p := range r.RecordGroups {
|
||
|
size += p.Size()
|
||
|
}
|
||
|
|
||
|
// Message end
|
||
|
size += len(hdrBODYEND)
|
||
|
|
||
|
// And closing sequence.
|
||
|
size += len(hdrMSGEND)
|
||
|
|
||
|
if r.common == nil || r.size == 0 {
|
||
|
r.common = new(common)
|
||
|
}
|
||
|
|
||
|
r.size = uint32(size)
|
||
|
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// ToMap returns a slice of slice of slice of FVP maps for this Message.
|
||
|
func (r *Response) ToMap() (m [][][]map[string]interface{}) {
|
||
|
|
||
|
m = make([][][]map[string]interface{}, len(r.RecordGroups))
|
||
|
for idx, rg := range r.RecordGroups {
|
||
|
m[idx] = rg.ToMap()
|
||
|
}
|
||
|
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// UnmarshalBinary populates a Response from packed bytes.
|
||
|
func (r *Response) UnmarshalBinary(data []byte) (err error) {
|
||
|
|
||
|
if data == nil || len(data) == 0 {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
var b []byte
|
||
|
var rgCnt, bodySize int
|
||
|
var rgSize int
|
||
|
var rgBuf *bytes.Buffer
|
||
|
var buf *bytes.Reader = bytes.NewReader(data)
|
||
|
var msgBuf *bytes.Buffer = new(bytes.Buffer)
|
||
|
|
||
|
if r == nil {
|
||
|
*r = Response{}
|
||
|
}
|
||
|
r.common = new(common)
|
||
|
r.size = 0
|
||
|
|
||
|
// Get the status.
|
||
|
if r.Status, err = buf.ReadByte(); err != nil {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// And the checksum -- responses *always* have checksums per spec!
|
||
|
// Toss the checksum header (after confirming).
|
||
|
b = make([]byte, len(hdrCKSUM))
|
||
|
if _, err = buf.Read(b); err != nil {
|
||
|
return
|
||
|
}
|
||
|
if !bytes.Equal(b, hdrCKSUM) {
|
||
|
err = ErrBadHdr
|
||
|
return
|
||
|
}
|
||
|
// And get the checksum.
|
||
|
b = make([]byte, CksumPackedSize)
|
||
|
if _, err = buf.Read(b); err != nil {
|
||
|
return
|
||
|
}
|
||
|
r.Checksum = UnpackUint32(b)
|
||
|
|
||
|
// Read (and toss) the message start header.
|
||
|
if _, err = buf.Read(make([]byte, len(hdrMSGSTART))); err != nil {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// Get the protocol version.
|
||
|
b = make([]byte, PackedNumSize)
|
||
|
if _, err = buf.Read(b); err != nil {
|
||
|
return
|
||
|
}
|
||
|
r.ProtocolVersion = UnpackUint32(b)
|
||
|
|
||
|
// Skip over the BODYSTART (but write it to msgBuf).
|
||
|
if _, err = io.CopyN(msgBuf, buf, int64(len(hdrBODYSTART))); err != nil {
|
||
|
return
|
||
|
}
|
||
|
// Get the count of record groups
|
||
|
b = make([]byte, PackedNumSize)
|
||
|
if _, err = buf.Read(b); err != nil {
|
||
|
return
|
||
|
}
|
||
|
rgCnt = UnpackInt(b)
|
||
|
if _, err = msgBuf.Write(b); err != nil {
|
||
|
return
|
||
|
}
|
||
|
// Get the size of record groups
|
||
|
b = make([]byte, PackedNumSize)
|
||
|
if _, err = buf.Read(b); err != nil {
|
||
|
return
|
||
|
}
|
||
|
bodySize = UnpackInt(b)
|
||
|
if _, err = msgBuf.Write(b); err != nil {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// And the record groups themselves.
|
||
|
if _, err = io.CopyN(msgBuf, buf, int64(bodySize+len(hdrBODYEND))); err != nil {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// Now validate the checksum before continuing.
|
||
|
if crc32.ChecksumIEEE(msgBuf.Bytes()) != r.Checksum {
|
||
|
err = ErrBadCksum
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// Now that we've validated the checksum, we trim the msgBuf to only RGs.
|
||
|
// Skip over the BODYSTART, record group count, and record group size.
|
||
|
if _, err = msgBuf.Read(make([]byte, len(hdrBODYSTART)+(PackedNumSize*2))); err != nil {
|
||
|
return
|
||
|
}
|
||
|
// Then truncate.
|
||
|
msgBuf.Truncate(bodySize)
|
||
|
|
||
|
r.RecordGroups = make([]*ResponseRecordGroup, rgCnt)
|
||
|
|
||
|
for idx := 0; idx < rgCnt; idx++ {
|
||
|
rgBuf = new(bytes.Buffer)
|
||
|
|
||
|
// The RG unmarshaler handles the record count, but we need to read it into msgBuf.
|
||
|
if _, err = io.CopyN(rgBuf, msgBuf, int64(PackedNumSize)); err != nil {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
b = make([]byte, PackedNumSize)
|
||
|
if _, err = msgBuf.Read(b); err != nil {
|
||
|
return
|
||
|
}
|
||
|
if _, err = rgBuf.Write(b); err != nil {
|
||
|
return
|
||
|
}
|
||
|
rgSize = UnpackInt(b)
|
||
|
|
||
|
if _, err = io.CopyN(rgBuf, msgBuf, int64(rgSize)); err != nil {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
r.RecordGroups[idx] = new(ResponseRecordGroup)
|
||
|
if err = r.RecordGroups[idx].UnmarshalBinary(rgBuf.Bytes()); err != nil {
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
|
||
|
_ = r.Size()
|
||
|
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// getIdx is a NOOP for Messages, but is used for Model conformance.
|
||
|
func (r *Response) getIdx() (idx int) {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// getRecordGroups returns the RecordGroups in this Message.
|
||
|
func (r *Response) getRecordGroups() (recordGroups []RecordGroup) {
|
||
|
|
||
|
recordGroups = make([]RecordGroup, len(r.RecordGroups))
|
||
|
for idx, rg := range r.RecordGroups {
|
||
|
recordGroups[idx] = rg
|
||
|
}
|
||
|
|
||
|
return
|
||
|
}
|