d4bb259b83
Initial release.
422 lines
9.3 KiB
Go
422 lines
9.3 KiB
Go
package wireproto
|
|
|
|
import (
|
|
`bytes`
|
|
`cmp`
|
|
`fmt`
|
|
`hash/crc32`
|
|
`io`
|
|
`strings`
|
|
|
|
`github.com/google/uuid`
|
|
)
|
|
|
|
// ConnId returns a copy of the connection ID.
|
|
func (r *Request) ConnId() (conn uuid.UUID) {
|
|
|
|
conn = r.connId
|
|
|
|
return
|
|
}
|
|
|
|
// MarshalBinary renders a Request into a byte-packed format.
|
|
func (r *Request) MarshalBinary() (data []byte, err error) {
|
|
|
|
var b []byte
|
|
var msgSize int
|
|
var cksum uint32
|
|
var buf *bytes.Buffer = new(bytes.Buffer)
|
|
var msgBuf *bytes.Buffer = new(bytes.Buffer)
|
|
|
|
_ = r.Size()
|
|
|
|
for _, i := range r.RecordGroups {
|
|
msgSize += i.Size()
|
|
}
|
|
|
|
// The message "body" - we do this first so we can checksum (if not nil).
|
|
if _, err = msgBuf.Write(hdrBODYSTART); err != nil {
|
|
return
|
|
}
|
|
// Record group count
|
|
if _, err = msgBuf.Write(PackInt(len(r.RecordGroups))); err != nil {
|
|
return
|
|
}
|
|
// And size.
|
|
if _, err = msgBuf.Write(PackInt(msgSize)); err != nil {
|
|
return
|
|
}
|
|
for _, i := range r.RecordGroups {
|
|
if b, err = i.MarshalBinary(); err != nil {
|
|
return
|
|
}
|
|
if _, err = msgBuf.Write(b); err != nil {
|
|
return
|
|
}
|
|
}
|
|
if _, err = msgBuf.Write(hdrBODYEND); err != nil {
|
|
return
|
|
}
|
|
|
|
// Now the surrounding request.
|
|
|
|
// Checksum - update and serialize if not null.
|
|
if r.Checksum != nil {
|
|
cksum = crc32.ChecksumIEEE(msgBuf.Bytes())
|
|
*r.Checksum = cksum
|
|
if _, err = buf.Write(hdrCKSUM); err != nil {
|
|
return
|
|
}
|
|
if _, err = buf.Write(cksumBytes(*r.Checksum)); err != nil {
|
|
return
|
|
}
|
|
}
|
|
|
|
// Message start
|
|
if _, err = buf.Write(hdrMSGSTART); err != nil {
|
|
return
|
|
}
|
|
|
|
// Protocol version
|
|
if _, err = buf.Write(PackUint32(r.ProtocolVersion)); err != nil {
|
|
return
|
|
}
|
|
|
|
// Then copy the msgBuf in.
|
|
if _, err = msgBuf.WriteTo(buf); err != nil {
|
|
return
|
|
}
|
|
// And then the message end.
|
|
if _, err = buf.Write(hdrMSGEND); err != nil {
|
|
return
|
|
}
|
|
|
|
data = buf.Bytes()
|
|
|
|
return
|
|
}
|
|
|
|
// Model returns an indented string representation of the model.
|
|
func (r *Request) Model() (out string) {
|
|
|
|
out = r.ModelCustom(IndentChars, SeparatorChars, indentR)
|
|
|
|
return
|
|
}
|
|
|
|
func (r *Request) ModelCustom(indent, sep string, level uint) (out string) {
|
|
|
|
var maxFtr int
|
|
var size int
|
|
var sb strings.Builder
|
|
|
|
for _, rg := range r.RecordGroups {
|
|
size += rg.Size()
|
|
}
|
|
|
|
// Checksum (optional for Request)
|
|
sb.WriteString(strings.Repeat(indent, int(level)))
|
|
if r.Checksum == nil {
|
|
sb.WriteString(strings.Repeat("-", 8))
|
|
sb.WriteString(sep)
|
|
sb.WriteString("// (No Checksum Present)\n")
|
|
} else {
|
|
// HDR: CKSUM
|
|
sb.WriteString(padBytesRight(hdrCKSUM, 8))
|
|
sb.WriteString(sep)
|
|
sb.WriteString("// HDR:CKSUM\n")
|
|
// Checksum
|
|
sb.WriteString(strings.Repeat(indent, int(level)))
|
|
sb.WriteString(padBytesRight(cksumBytes(*r.Checksum), 8))
|
|
sb.WriteString(sep)
|
|
sb.WriteString(fmt.Sprintf("// Checksum Value (%d)\n", *r.Checksum))
|
|
}
|
|
|
|
// Header: MSGSTART
|
|
sb.WriteString(strings.Repeat(indent, int(level)))
|
|
sb.WriteString(padBytesRight(hdrMSGSTART, 8))
|
|
sb.WriteString(sep)
|
|
sb.WriteString("// HDR:MSGSTART\n")
|
|
|
|
// Protocol Version
|
|
sb.WriteString(strings.Repeat(indent, int(level)))
|
|
sb.WriteString(padIntRight(int(r.ProtocolVersion), 8))
|
|
sb.WriteString(sep)
|
|
sb.WriteString(fmt.Sprintf("// Protocol Version (%d)\n", r.ProtocolVersion))
|
|
|
|
// Header: BODYSTART
|
|
sb.WriteString(strings.Repeat(indent, int(level)))
|
|
sb.WriteString(padBytesRight(hdrBODYSTART, 8))
|
|
sb.WriteString(sep)
|
|
sb.WriteString("// HDR:BODYSTART\n")
|
|
|
|
// Count
|
|
sb.WriteString(strings.Repeat(indent, int(level)))
|
|
sb.WriteString(padIntRight(len(r.RecordGroups), 8))
|
|
sb.WriteString(sep)
|
|
sb.WriteString(fmt.Sprintf("// Record Group Count (%d)\n", len(r.RecordGroups)))
|
|
// Size
|
|
sb.WriteString(strings.Repeat(indent, int(level)))
|
|
sb.WriteString(padIntRight(size, 8))
|
|
sb.WriteString(sep)
|
|
sb.WriteString(fmt.Sprintf("// Record Groups Size (%d)\n", size))
|
|
|
|
// VALUES
|
|
for idx, rg := range r.RecordGroups {
|
|
sb.WriteString(strings.Repeat(indent, int(level)))
|
|
sb.WriteString(fmt.Sprintf("// Record Group %d (%d bytes)\n", idx+1, rg.Size()))
|
|
sb.WriteString(rg.ModelCustom(indent, sep, level+1))
|
|
}
|
|
|
|
// Make the footers a little more nicely aligned.
|
|
switch cmp.Compare(len(hdrBODYEND), len(hdrMSGEND)) {
|
|
case -1:
|
|
maxFtr = len(hdrMSGEND)
|
|
case 1, 0:
|
|
maxFtr = len(hdrBODYEND)
|
|
}
|
|
|
|
// Footer: BODYEND
|
|
sb.WriteString(strings.Repeat(indent, int(level)))
|
|
sb.WriteString(padBytesRight(hdrBODYEND, maxFtr))
|
|
sb.WriteString(sep)
|
|
sb.WriteString("// HDR:BODYEND\n")
|
|
|
|
// Footer: MSGEND
|
|
sb.WriteString(strings.Repeat(indent, int(level)))
|
|
sb.WriteString(padBytesRight(hdrMSGEND, maxFtr))
|
|
sb.WriteString(sep)
|
|
sb.WriteString("// HDR:MSGEND\n")
|
|
|
|
out = sb.String()
|
|
|
|
return
|
|
}
|
|
|
|
// Resolve associates children with parents.
|
|
func (r *Request) Resolve() {
|
|
for idx, i := range r.RecordGroups {
|
|
i.parent = r
|
|
i.rgIdx = idx
|
|
i.Resolve()
|
|
}
|
|
}
|
|
|
|
// SetConnID allows for setting a connection ID. This is largely just used for debugging purposes.
|
|
func (r *Request) SetConnID(connId uuid.UUID) {
|
|
|
|
r.connId = connId
|
|
|
|
for _, rg := range r.RecordGroups {
|
|
rg.connId = connId
|
|
for _, rec := range rg.Records {
|
|
rec.connId = connId
|
|
for _, kvp := range rec.Pairs {
|
|
kvp.connId = &connId
|
|
}
|
|
}
|
|
}
|
|
|
|
}
|
|
|
|
// Size returns the Request's calculated size (in bytes) and updates the size field if 0.
|
|
func (r *Request) Size() (size int) {
|
|
|
|
if r == nil {
|
|
return
|
|
}
|
|
|
|
// Checksum
|
|
if r.Checksum != nil {
|
|
size += len(hdrCKSUM)
|
|
size += CksumPackedSize
|
|
}
|
|
|
|
// Message header
|
|
size += len(hdrMSGSTART)
|
|
|
|
// Protocol version
|
|
size += PackedNumSize
|
|
|
|
// Count and Size uint32s
|
|
size += PackedNumSize * 2
|
|
|
|
// Message begin
|
|
size += len(hdrBODYSTART)
|
|
|
|
for _, p := range r.RecordGroups {
|
|
size += p.Size()
|
|
}
|
|
|
|
// Message end
|
|
size += len(hdrBODYEND)
|
|
|
|
// And closing sequence.
|
|
size += len(hdrMSGEND)
|
|
|
|
if r.common == nil {
|
|
r.common = new(common)
|
|
}
|
|
|
|
r.size = uint32(size)
|
|
|
|
return
|
|
}
|
|
|
|
// ToMap returns a slice of slice of slice of FVP maps for this Message.
|
|
func (r *Request) ToMap() (m [][][]map[string]interface{}) {
|
|
|
|
m = make([][][]map[string]interface{}, len(r.RecordGroups))
|
|
for idx, rg := range r.RecordGroups {
|
|
m[idx] = rg.ToMap()
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
// UnmarshalBinary populates a Request from packed bytes.
|
|
func (r *Request) UnmarshalBinary(data []byte) (err error) {
|
|
|
|
if data == nil || len(data) == 0 {
|
|
return
|
|
}
|
|
|
|
var b []byte
|
|
var rgCnt, bodySize int
|
|
var rgSize int
|
|
var rgBuf *bytes.Buffer
|
|
var buf *bytes.Reader = bytes.NewReader(data)
|
|
var msgBuf *bytes.Buffer = new(bytes.Buffer)
|
|
|
|
if r == nil {
|
|
*r = Request{}
|
|
}
|
|
r.common = &common{}
|
|
r.size = 0
|
|
|
|
// Check for a checksum.
|
|
b = make([]byte, len(hdrMSGSTART))
|
|
if _, err = buf.Read(b); err != nil {
|
|
return
|
|
}
|
|
if bytes.Equal(b, hdrCKSUM) {
|
|
// A checksum header was found.
|
|
b = make([]byte, CksumPackedSize)
|
|
if _, err = buf.Read(b); err != nil {
|
|
return
|
|
}
|
|
r.Checksum = new(uint32)
|
|
*r.Checksum = byteOrder.Uint32(b)
|
|
// Since we've only read the checksum, we now also have to read in the MSGSTART...
|
|
b = make([]byte, len(hdrMSGSTART))
|
|
if _, err = buf.Read(b); err != nil {
|
|
return
|
|
}
|
|
// But we don't need to do anything with it.
|
|
} else {
|
|
// We've already read MSGSTART as part of the checksum check.
|
|
r.Checksum = nil
|
|
}
|
|
|
|
// Get the protocol version.
|
|
b = make([]byte, PackedNumSize)
|
|
if _, err = buf.Read(b); err != nil {
|
|
return
|
|
}
|
|
r.ProtocolVersion = UnpackUint32(b)
|
|
|
|
// Skip over the BODYSTART (but write it to msgBuf).
|
|
if _, err = io.CopyN(msgBuf, buf, int64(len(hdrBODYSTART))); err != nil {
|
|
return
|
|
}
|
|
|
|
// Get the count of record groups
|
|
b = make([]byte, PackedNumSize)
|
|
if _, err = buf.Read(b); err != nil {
|
|
return
|
|
}
|
|
rgCnt = UnpackInt(b)
|
|
if _, err = msgBuf.Write(b); err != nil {
|
|
return
|
|
}
|
|
// And their size
|
|
b = make([]byte, PackedNumSize)
|
|
if _, err = buf.Read(b); err != nil {
|
|
return
|
|
}
|
|
bodySize = UnpackInt(b)
|
|
if _, err = msgBuf.Write(b); err != nil {
|
|
return
|
|
}
|
|
|
|
// And the record groups themselves.
|
|
if _, err = io.CopyN(msgBuf, buf, int64(bodySize+len(hdrBODYEND))); err != nil {
|
|
return
|
|
}
|
|
|
|
// Validate the checksum.
|
|
if r.Checksum != nil {
|
|
if crc32.ChecksumIEEE(msgBuf.Bytes()) != *r.Checksum {
|
|
err = ErrBadCksum
|
|
return
|
|
}
|
|
}
|
|
|
|
// Now that we've validated the checksum (if provided), we trim the msgBuf to only RGs.
|
|
// Skip over the BODYSTART, record group count, and record group size.
|
|
if _, err = msgBuf.Read(make([]byte, len(hdrBODYSTART)+(PackedNumSize*2))); err != nil {
|
|
return
|
|
}
|
|
// Then truncate.
|
|
msgBuf.Truncate(bodySize)
|
|
|
|
r.RecordGroups = make([]*RequestRecordGroup, rgCnt)
|
|
|
|
for idx := 0; idx < rgCnt; idx++ {
|
|
rgBuf = new(bytes.Buffer)
|
|
|
|
// The RG unmarshaler handles the record count, but we need to read it to discard it in msgBuf.
|
|
if _, err = io.CopyN(rgBuf, msgBuf, int64(PackedNumSize)); err != nil {
|
|
return
|
|
}
|
|
|
|
b = make([]byte, PackedNumSize)
|
|
if _, err = msgBuf.Read(b); err != nil {
|
|
return
|
|
}
|
|
if _, err = rgBuf.Write(b); err != nil {
|
|
return
|
|
}
|
|
rgSize = UnpackInt(b)
|
|
|
|
if _, err = io.CopyN(rgBuf, msgBuf, int64(rgSize)); err != nil {
|
|
return
|
|
}
|
|
|
|
r.RecordGroups[idx] = new(RequestRecordGroup)
|
|
if err = r.RecordGroups[idx].UnmarshalBinary(rgBuf.Bytes()); err != nil {
|
|
return
|
|
}
|
|
}
|
|
|
|
_ = r.Size()
|
|
|
|
return
|
|
}
|
|
|
|
// getIdx is a NOOP for Messages, but is used for Model conformance.
|
|
func (r *Request) getIdx() (idx int) {
|
|
return
|
|
}
|
|
|
|
// getRecordGroups returns the RecordGroups in this Message.
|
|
func (r *Request) getRecordGroups() (recordGroups []RecordGroup) {
|
|
|
|
recordGroups = make([]RecordGroup, len(r.RecordGroups))
|
|
for idx, rg := range r.RecordGroups {
|
|
recordGroups[idx] = rg
|
|
}
|
|
|
|
return
|
|
}
|