go_wireproto/funcs_request.go

422 lines
9.3 KiB
Go
Raw Normal View History

2024-07-09 23:40:20 -04:00
package wireproto
import (
`bytes`
`cmp`
`fmt`
`hash/crc32`
`io`
`strings`
`github.com/google/uuid`
)
// ConnId returns the connection ID stored on this Request. uuid.UUID is an
// array value, so the caller receives its own copy.
func (r *Request) ConnId() (conn uuid.UUID) {
	return r.connId
}
// MarshalBinary renders a Request into a byte-packed format.
//
// Output layout:
//
//	[hdrCKSUM, checksum]  only when r.Checksum != nil
//	hdrMSGSTART, protocol version,
//	hdrBODYSTART, record group count, record groups size,
//	each record group's MarshalBinary output, hdrBODYEND,
//	hdrMSGEND
//
// When r.Checksum is non-nil, *r.Checksum is overwritten with the CRC-32 (IEEE)
// of the body section (hdrBODYSTART through hdrBODYEND inclusive) before it is
// written out.
func (r *Request) MarshalBinary() (data []byte, err error) {
	var body, out bytes.Buffer
	var groupBytes []byte
	var groupsLen int

	// Refresh the cached size (r.Size also initializes r.common when it is nil).
	_ = r.Size()
	for _, rg := range r.RecordGroups {
		groupsLen += rg.Size()
	}

	// Assemble the body first so it can be checksummed before the envelope is written.
	for _, chunk := range [][]byte{hdrBODYSTART, PackInt(len(r.RecordGroups)), PackInt(groupsLen)} {
		if _, err = body.Write(chunk); err != nil {
			return
		}
	}
	for _, rg := range r.RecordGroups {
		if groupBytes, err = rg.MarshalBinary(); err != nil {
			return
		}
		if _, err = body.Write(groupBytes); err != nil {
			return
		}
	}
	if _, err = body.Write(hdrBODYEND); err != nil {
		return
	}

	// Envelope: the optional checksum comes before MSGSTART and the protocol version.
	if r.Checksum != nil {
		*r.Checksum = crc32.ChecksumIEEE(body.Bytes())
		for _, chunk := range [][]byte{hdrCKSUM, cksumBytes(*r.Checksum)} {
			if _, err = out.Write(chunk); err != nil {
				return
			}
		}
	}
	for _, chunk := range [][]byte{hdrMSGSTART, PackUint32(r.ProtocolVersion)} {
		if _, err = out.Write(chunk); err != nil {
			return
		}
	}
	// Splice the finished body into the envelope, then close the message.
	if _, err = body.WriteTo(&out); err != nil {
		return
	}
	if _, err = out.Write(hdrMSGEND); err != nil {
		return
	}

	data = out.Bytes()
	return
}
// Model returns an indented string representation of the model. It is
// ModelCustom called with IndentChars, SeparatorChars and the starting level
// indentR.
func (r *Request) Model() (out string) {
	return r.ModelCustom(IndentChars, SeparatorChars, indentR)
}
// ModelCustom returns an annotated, line-per-field rendering of the Request's
// wire layout. Fields appear in the same order MarshalBinary emits them.
//
// Each top-level line consists of:
//   - indent repeated level times;
//   - the field's padded rendering;
//   - sep;
//   - a "// ..." annotation.
//
// Record groups are rendered by their own ModelCustom at level+1. The checksum
// line shows the current value of *r.Checksum, which MarshalBinary refreshes;
// ModelCustom does not recompute it.
func (r *Request) ModelCustom(indent, sep string, level uint) (out string) {
	var sb strings.Builder
	var size int
	var maxFtr int

	// The prefix is identical for every top-level line, so build it once.
	pfx := strings.Repeat(indent, int(level))
	// row writes one annotated line: prefix, field, separator, comment, newline.
	row := func(field, comment string) {
		sb.WriteString(pfx)
		sb.WriteString(field)
		sb.WriteString(sep)
		sb.WriteString(comment)
		sb.WriteString("\n")
	}

	for _, rg := range r.RecordGroups {
		size += rg.Size()
	}

	// Checksum (optional for Request).
	if r.Checksum == nil {
		row(strings.Repeat("-", 8), "// (No Checksum Present)")
	} else {
		row(padBytesRight(hdrCKSUM, 8), "// HDR:CKSUM")
		row(padBytesRight(cksumBytes(*r.Checksum), 8), fmt.Sprintf("// Checksum Value (%d)", *r.Checksum))
	}

	// Message header and protocol version.
	row(padBytesRight(hdrMSGSTART, 8), "// HDR:MSGSTART")
	row(padIntRight(int(r.ProtocolVersion), 8), fmt.Sprintf("// Protocol Version (%d)", r.ProtocolVersion))

	// Body header: BODYSTART, record group count and total record group size.
	row(padBytesRight(hdrBODYSTART, 8), "// HDR:BODYSTART")
	row(padIntRight(len(r.RecordGroups), 8), fmt.Sprintf("// Record Group Count (%d)", len(r.RecordGroups)))
	row(padIntRight(size, 8), fmt.Sprintf("// Record Groups Size (%d)", size))

	// Values: each record group, rendered one level deeper.
	for idx, rg := range r.RecordGroups {
		sb.WriteString(pfx)
		fmt.Fprintf(&sb, "// Record Group %d (%d bytes)\n", idx+1, rg.Size())
		sb.WriteString(rg.ModelCustom(indent, sep, level+1))
	}

	// Pad both footers to the longer of the two so their separators line up.
	maxFtr = len(hdrBODYEND)
	if cmp.Compare(len(hdrBODYEND), len(hdrMSGEND)) < 0 {
		maxFtr = len(hdrMSGEND)
	}
	row(padBytesRight(hdrBODYEND, maxFtr), "// HDR:BODYEND")
	row(padBytesRight(hdrMSGEND, maxFtr), "// HDR:MSGEND")

	out = sb.String()
	return
}
// Resolve associates children with parents. Each record group receives this
// Request as its parent and its slice position as its index, and then has its
// own Resolve called.
func (r *Request) Resolve() {
	for pos := range r.RecordGroups {
		group := r.RecordGroups[pos]
		group.parent = r
		group.rgIdx = pos
		group.Resolve()
	}
}
// SetConnID sets the connection ID on the Request and copies it onto every
// record group, record, and pair it contains. This is largely just used for
// debugging purposes.
//
// Pairs store a pointer. All pairs end up sharing the same pointer, which
// refers to this call's connId argument.
func (r *Request) SetConnID(connId uuid.UUID) {
	r.connId = connId
	shared := &connId
	for _, group := range r.RecordGroups {
		group.connId = connId
		for _, record := range group.Records {
			record.connId = connId
			for _, pair := range record.Pairs {
				pair.connId = shared
			}
		}
	}
}
// Size returns the Request's calculated marshaled size in bytes.
//
// It also stores the result in r.size on every call, not only when that field
// is 0, allocating r.common first if it is nil. The cached value is therefore
// refreshed each time. A nil Request reports 0.
func (r *Request) Size() (size int) {
	if r == nil {
		return
	}
	// Optional checksum header and value.
	if r.Checksum != nil {
		size += len(hdrCKSUM) + CksumPackedSize
	}
	// MSGSTART, protocol version, BODYSTART, then the record group count and size.
	size += len(hdrMSGSTART) + PackedNumSize + len(hdrBODYSTART) + PackedNumSize*2
	for _, rg := range r.RecordGroups {
		size += rg.Size()
	}
	// BODYEND and the closing MSGEND.
	size += len(hdrBODYEND) + len(hdrMSGEND)

	if r.common == nil {
		r.common = new(common)
	}
	r.size = uint32(size)
	return
}
// ToMap returns one entry per record group, in order. Each entry is that
// group's own ToMap result, so the overall result is a slice of slices of
// slices of FVP maps.
func (r *Request) ToMap() (m [][][]map[string]interface{}) {
	m = make([][][]map[string]interface{}, 0, len(r.RecordGroups))
	for _, group := range r.RecordGroups {
		m = append(m, group.ToMap())
	}
	return
}
// UnmarshalBinary populates a Request from packed bytes.
//
// Bytes are consumed in the order MarshalBinary writes them:
//
//	[hdrCKSUM, checksum]  (optional)
//	hdrMSGSTART, protocol version,
//	hdrBODYSTART, record group count, record groups size,
//	record groups..., hdrBODYEND
//
// Anything after hdrBODYEND (i.e. hdrMSGEND) is not read. MSGSTART, BODYSTART
// and BODYEND are consumed without being validated. If a checksum is present,
// it is compared with the CRC-32 (IEEE) of hdrBODYSTART through hdrBODYEND
// inclusive, and ErrBadCksum is returned on mismatch. Empty data is a no-op.
//
// Truncated input fails with io.EOF or io.ErrUnexpectedEOF rather than being
// decoded from zero-padded buffers. Negative or impossible count/size fields
// are rejected before they are used to allocate or slice.
func (r *Request) UnmarshalBinary(data []byte) (err error) {
	if len(data) == 0 {
		return
	}
	if r == nil {
		// There is nothing to populate; dereferencing r would panic.
		err = fmt.Errorf("cannot unmarshal into a nil *Request")
		return
	}
	var b []byte
	var rgCnt, bodySize, rgSize int
	var rgBuf *bytes.Buffer
	var buf *bytes.Reader = bytes.NewReader(data)
	var msgBuf *bytes.Buffer = new(bytes.Buffer)
	// readFull reads exactly n bytes from src. A bare Read may return fewer bytes
	// with a nil error on short input; io.ReadFull reports that as an error.
	readFull := func(src io.Reader, n int) (p []byte, rerr error) {
		p = make([]byte, n)
		_, rerr = io.ReadFull(src, p)
		return
	}
	r.common = &common{}
	r.size = 0
	// Check for a checksum.
	// NOTE: this reads len(hdrMSGSTART) bytes and compares them to hdrCKSUM,
	// so it assumes both headers are the same length.
	if b, err = readFull(buf, len(hdrMSGSTART)); err != nil {
		return
	}
	if bytes.Equal(b, hdrCKSUM) {
		// A checksum header was found; the checksum value follows it.
		if b, err = readFull(buf, CksumPackedSize); err != nil {
			return
		}
		r.Checksum = new(uint32)
		*r.Checksum = byteOrder.Uint32(b)
		// MSGSTART comes after the checksum; consume it without inspecting it.
		if _, err = readFull(buf, len(hdrMSGSTART)); err != nil {
			return
		}
	} else {
		// The bytes already read were MSGSTART.
		r.Checksum = nil
	}
	// Get the protocol version.
	if b, err = readFull(buf, PackedNumSize); err != nil {
		return
	}
	r.ProtocolVersion = UnpackUint32(b)
	// Skip over the BODYSTART, but keep it in msgBuf because it is checksummed.
	if _, err = io.CopyN(msgBuf, buf, int64(len(hdrBODYSTART))); err != nil {
		return
	}
	// Get the count of record groups.
	if b, err = readFull(buf, PackedNumSize); err != nil {
		return
	}
	rgCnt = UnpackInt(b)
	if _, err = msgBuf.Write(b); err != nil {
		return
	}
	// And their total size.
	if b, err = readFull(buf, PackedNumSize); err != nil {
		return
	}
	bodySize = UnpackInt(b)
	if _, err = msgBuf.Write(b); err != nil {
		return
	}
	// And the record groups themselves, plus BODYEND.
	if _, err = io.CopyN(msgBuf, buf, int64(bodySize+len(hdrBODYEND))); err != nil {
		return
	}
	// Validate the checksum before any size sanity checks, so a corrupted
	// message still reports ErrBadCksum.
	if r.Checksum != nil {
		if crc32.ChecksumIEEE(msgBuf.Bytes()) != *r.Checksum {
			err = ErrBadCksum
			return
		}
	}
	// A negative size would make msgBuf.Truncate panic below.
	if bodySize < 0 {
		err = fmt.Errorf("invalid record groups size %d", bodySize)
		return
	}
	// Each record group consumes at least its record count and size fields
	// (2*PackedNumSize bytes) from the body. A larger count can never decode, so
	// reject it instead of allocating a slice sized by an untrusted value.
	if rgCnt < 0 || rgCnt > bodySize/(PackedNumSize*2) {
		err = fmt.Errorf("invalid record group count %d for body size %d", rgCnt, bodySize)
		return
	}
	// Skip BODYSTART, the record group count and the record group size. These
	// bytes are known to be present because they were written into msgBuf above.
	msgBuf.Next(len(hdrBODYSTART) + PackedNumSize*2)
	// Then drop the trailing BODYEND, leaving exactly the record groups.
	msgBuf.Truncate(bodySize)
	r.RecordGroups = make([]*RequestRecordGroup, rgCnt)
	for idx := 0; idx < rgCnt; idx++ {
		rgBuf = new(bytes.Buffer)
		// The RG unmarshaler parses the record count itself, so pass it through.
		if _, err = io.CopyN(rgBuf, msgBuf, int64(PackedNumSize)); err != nil {
			return
		}
		if b, err = readFull(msgBuf, PackedNumSize); err != nil {
			return
		}
		if _, err = rgBuf.Write(b); err != nil {
			return
		}
		rgSize = UnpackInt(b)
		// io.CopyN silently copies nothing for a negative count, so check it explicitly.
		if rgSize < 0 {
			err = fmt.Errorf("invalid size %d for record group %d", rgSize, idx)
			return
		}
		if _, err = io.CopyN(rgBuf, msgBuf, int64(rgSize)); err != nil {
			return
		}
		r.RecordGroups[idx] = new(RequestRecordGroup)
		if err = r.RecordGroups[idx].UnmarshalBinary(rgBuf.Bytes()); err != nil {
			return
		}
	}
	_ = r.Size()
	return
}
// getIdx is a no-op for a Request: it always returns 0. It exists only for
// Model conformance.
func (r *Request) getIdx() (idx int) {
	idx = 0
	return
}
// getRecordGroups returns this Request's record groups, in order, as a slice of
// the RecordGroup interface. The result is non-nil even when there are none.
func (r *Request) getRecordGroups() (recordGroups []RecordGroup) {
	recordGroups = make([]RecordGroup, 0, len(r.RecordGroups))
	for _, group := range r.RecordGroups {
		recordGroups = append(recordGroups, group)
	}
	return
}