go_wireproto/funcs_response.go
brent saner 9b39811206
v1.0.1
FIX:
* Cleaned up some documentation
2024-07-10 00:40:12 -04:00

486 lines
10 KiB
Go

package wireproto
import (
`bytes`
`cmp`
`fmt`
`hash/crc32`
`io`
`strings`
`r00t2.io/goutils/multierr`
)
// GenChecksum (re-)generates and returns the CRC-32 (IEEE) checksum of the
// Response body, also storing it in r.Checksum.
//
// The checksummed body is BODYSTART, the record group count, the combined
// size of all record groups, each marshaled record group, then BODYEND; those
// exact bytes are returned in buf. If marshaling or writing fails, err is that
// error, cksum is 0 and r.Checksum is left unchanged.
func (r *Response) GenChecksum() (cksum uint32, buf *bytes.Buffer, err error) {
	var packed []byte
	var groupsSize int

	buf = new(bytes.Buffer)

	// Refresh the cached Response size as a side effect.
	_ = r.Size()

	for _, rg := range r.RecordGroups {
		groupsSize += rg.Size()
	}

	// Body header: start marker, record group count, combined record group size.
	for _, chunk := range [][]byte{
		hdrBODYSTART,
		PackInt(len(r.RecordGroups)),
		PackInt(groupsSize),
	} {
		if _, err = buf.Write(chunk); err != nil {
			return
		}
	}

	for _, rg := range r.RecordGroups {
		if packed, err = rg.MarshalBinary(); err != nil {
			return
		}
		if _, err = buf.Write(packed); err != nil {
			return
		}
	}

	if _, err = buf.Write(hdrBODYEND); err != nil {
		return
	}

	cksum = crc32.ChecksumIEEE(buf.Bytes())
	r.Checksum = cksum
	return
}
// RespStat sets r.Status from the Response's contents: RespStatusByteErr if
// any field in any record of any record group is named "error"
// (case-insensitively), otherwise RespStatusByteOK.
func (r *Response) RespStat() {
	r.Status = RespStatusByteOK
scan:
	for _, group := range r.RecordGroups {
		for _, record := range group.Records {
			for _, pair := range record.Pairs {
				if strings.EqualFold(pair.Name.String(), "error") {
					r.Status = RespStatusByteErr
					break scan
				}
			}
		}
	}
}
// MarshalBinary renders a Response into a byte-packed format.
//
// Layout: status byte, CKSUM header, checksum, MSGSTART header, protocol
// version, then the body (BODYSTART, record group count, combined record
// group size, each marshaled record group, BODYEND), then MSGEND.
//
// A record group that fails to marshal does not abort the call. Its error is
// collected, the status is flipped from OK to error, and the packed bytes
// (possibly incomplete) are still returned in data. In that case err is the
// collected *multierr.MultiError. The checksum always covers exactly the body
// bytes that were emitted.
func (r *Response) MarshalBinary() (data []byte, err error) {
	var b []byte
	var msgSize int
	var hasErr bool
	var mErr *multierr.MultiError = multierr.NewMultiError(nil)
	var buf *bytes.Buffer = new(bytes.Buffer)
	var msgBuf *bytes.Buffer = new(bytes.Buffer)

	// Refresh the cached Response size.
	_ = r.Size()
	for _, i := range r.RecordGroups {
		msgSize += i.Size()
	}

	// The message "body" - this is done first so it can be checksummed.
	if _, err = msgBuf.Write(hdrBODYSTART); err != nil {
		return
	}
	// Record group count
	if _, err = msgBuf.Write(PackInt(len(r.RecordGroups))); err != nil {
		return
	}
	// And size.
	if _, err = msgBuf.Write(PackInt(msgSize)); err != nil {
		return
	}
	for _, i := range r.RecordGroups {
		if b, err = i.MarshalBinary(); err != nil {
			mErr.AddError(err)
			err = nil
			hasErr = true
		}
		if _, err = msgBuf.Write(b); err != nil {
			mErr.AddError(err)
			err = nil
			hasErr = true
		}
	}
	if _, err = msgBuf.Write(hdrBODYEND); err != nil {
		mErr.AddError(err)
		err = nil
		hasErr = true
	}

	// Now write the response as a whole.
	// Status: downgrade OK to error if any part of the body failed.
	if r.Status == RespStatusByteOK && hasErr {
		r.Status = RespStatusByteErr
	}
	if _, err = buf.Write([]byte{r.Status}); err != nil {
		return
	}

	// Checksum; ALWAYS present for responses!
	// Computed over the body already assembled in msgBuf, before WriteTo drains it.
	// This avoids marshaling every record group a second time. It also means a
	// record group error cannot abort here and discard both mErr and the data.
	r.Checksum = crc32.ChecksumIEEE(msgBuf.Bytes())
	if _, err = buf.Write(hdrCKSUM); err != nil {
		return
	}
	if _, err = buf.Write(cksumBytes(r.Checksum)); err != nil {
		return
	}

	// Message start
	if _, err = buf.Write(hdrMSGSTART); err != nil {
		return
	}
	// Protocol version
	if _, err = buf.Write(PackUint32(r.ProtocolVersion)); err != nil {
		return
	}
	// Then copy the msgBuf in.
	if _, err = msgBuf.WriteTo(buf); err != nil {
		return
	}
	// And then the message end.
	if _, err = buf.Write(hdrMSGEND); err != nil {
		return
	}

	data = buf.Bytes()
	// Only assign when non-empty so callers never see a non-nil error
	// wrapping an empty MultiError.
	if !mErr.IsEmpty() {
		err = mErr
	}
	return
}
// Model returns an indented string representation of the model. It is
// ModelCustom called with IndentChars, SeparatorChars and the indentR level.
func (r *Response) Model() (out string) {
	return r.ModelCustom(IndentChars, SeparatorChars, indentR)
}
// ModelCustom is like Model with user-defined formatting.
//
// Each field row is written as: indent repeated level times, the field's
// padded byte rendering, sep, then a "//" annotation. Nested record groups are
// rendered via their own ModelCustom one level deeper. The checksum is
// regenerated first (best-effort) so the rendered value reflects the current
// record groups.
func (r *Response) ModelCustom(indent, sep string, level uint) (out string) {
	var maxFtr int
	var size int
	var statusNote string
	var sb strings.Builder
	var pfx string = strings.Repeat(indent, int(level))

	// row writes one "<pfx><field><sep><note>" line; note carries its own newline.
	row := func(field, note string) {
		sb.WriteString(pfx)
		sb.WriteString(field)
		sb.WriteString(sep)
		sb.WriteString(note)
	}

	for _, rg := range r.RecordGroups {
		size += rg.Size()
	}

	// Best-effort: on failure GenChecksum leaves r.Checksum unchanged, so the
	// rendered checksum may be stale, but the model is still worth showing.
	_, _, _ = r.GenChecksum()

	// HDR: RESPSTART/RESPERR (RESPSTATUS)
	switch r.Status {
	case RespStatusByteOK:
		statusNote = "// HDR:RESPSTART (Status: OK)\n"
	case RespStatusByteErr:
		statusNote = "// HDR:RESPERR (Status: Error)\n"
	default:
		// Without an annotation (and its newline) this row would run into the CKSUM row.
		statusNote = fmt.Sprintf("// HDR:RESP? (Status: Unknown 0x%02x)\n", r.Status)
	}
	row(padBytesRight([]byte{byte(r.Status)}, 8), statusNote)

	// HDR: CKSUM
	row(padBytesRight(hdrCKSUM, 8), "// HDR:CKSUM\n")
	// Checksum
	row(padBytesRight(cksumBytes(r.Checksum), 8), fmt.Sprintf("// Checksum Value (%d)\n", r.Checksum))
	// Header: MSGSTART
	row(padBytesRight(hdrMSGSTART, 8), "// HDR:MSGSTART\n")
	// Protocol Version
	row(padIntRight(int(r.ProtocolVersion), 8), fmt.Sprintf("// Protocol Version (%d)\n", r.ProtocolVersion))
	// Header: BODYSTART
	row(padBytesRight(hdrBODYSTART, 8), "// HDR:BODYSTART\n")
	// Count
	row(padIntRight(len(r.RecordGroups), 8), fmt.Sprintf("// Record Group Count (%d)\n", len(r.RecordGroups)))
	// Size
	row(padIntRight(size, 8), fmt.Sprintf("// Record Groups Size (%d)\n", size))

	// VALUES
	for idx, rg := range r.RecordGroups {
		fmt.Fprintf(&sb, "// Record Group %d (%d)\n", idx+1, rg.Size())
		sb.WriteString(rg.ModelCustom(indent, sep, level+1))
	}

	// Make the footers a little more nicely aligned: pad both to the longer one.
	maxFtr = len(hdrBODYEND)
	if cmp.Less(len(hdrBODYEND), len(hdrMSGEND)) {
		maxFtr = len(hdrMSGEND)
	}
	// Footer: BODYEND
	row(padBytesRight(hdrBODYEND, maxFtr), "// HDR:BODYEND\n")
	// Footer: MSGEND
	row(padBytesRight(hdrMSGEND, maxFtr), "// HDR:MSGEND\n")

	out = sb.String()
	return
}
// Resolve associates children with parents. Each record group gets this
// Response as its parent and its position in r.RecordGroups as its index.
// Each record group's own Resolve is then called.
func (r *Response) Resolve() {
	for pos := range r.RecordGroups {
		group := r.RecordGroups[pos]
		group.parent, group.rgIdx = r, pos
		group.Resolve()
	}
}
// Size returns the Response's calculated size (in bytes). The total covers:
//   - the status byte
//   - the CKSUM header and packed checksum
//   - the MSGSTART header and the protocol version
//   - the record group count and size fields and the BODYSTART header
//   - every record group
//   - the BODYEND and MSGEND footers
//
// The result is always stored in the cached size field, allocating the
// embedded common struct if it is nil. A nil Response has size 0.
func (r *Response) Size() (size int) {
	if r == nil {
		return
	}
	// Response Status
	size++
	// Checksum header and value
	size += len(hdrCKSUM)
	size += CksumPackedSize
	// Message header
	size += len(hdrMSGSTART)
	// Protocol version
	size += PackedNumSize
	// Count and Size uint32s
	size += PackedNumSize * 2
	// Message begin
	size += len(hdrBODYSTART)
	for _, rg := range r.RecordGroups {
		size += rg.Size()
	}
	// Message end
	size += len(hdrBODYEND)
	// And closing sequence.
	size += len(hdrMSGEND)

	// Only allocate common when it is missing. Re-allocating it merely because
	// the cached size is 0 would reset any other state held in it.
	if r.common == nil {
		r.common = new(common)
	}
	r.size = uint32(size)
	return
}
// ToMap returns a slice of slice of slice of FVP maps for this Response: one
// element per record group, each being that record group's own ToMap result.
func (r *Response) ToMap() (m [][][]map[string]interface{}) {
	m = make([][][]map[string]interface{}, 0, len(r.RecordGroups))
	for _, group := range r.RecordGroups {
		m = append(m, group.ToMap())
	}
	return
}
// UnmarshalBinary populates a Response from packed bytes.
//
// Empty or nil data is a no-op and returns a nil error. Otherwise data must
// start with, in order:
//   - a status byte
//   - the CKSUM header and the checksum
//   - the MSGSTART header and the protocol version
//   - the body: BODYSTART, record group count, combined record group size,
//     the record groups, then BODYEND
//
// Anything after BODYEND (such as the MSGEND footer) is not read.
//
// Before any record group is decoded, the body is verified against the
// checksum; a mismatch yields ErrBadCksum. A wrong CKSUM or MSGSTART header
// yields ErrBadHdr. Truncated input yields io.EOF or io.ErrUnexpectedEOF. A
// record group count or size that cannot fit the body yields a descriptive
// error. On error, fields decoded so far remain set on r.
func (r *Response) UnmarshalBinary(data []byte) (err error) {
	var b []byte
	var rgCnt, bodySize, rgSize int
	var rgBuf *bytes.Buffer
	var buf *bytes.Reader
	var msgBuf *bytes.Buffer = new(bytes.Buffer)

	if len(data) == 0 {
		return
	}
	if r == nil {
		// There is nowhere to store the result; dereferencing r would panic.
		err = fmt.Errorf("cannot unmarshal into a nil *Response")
		return
	}

	buf = bytes.NewReader(data)
	r.common = new(common)
	r.size = 0

	// Get the status.
	if r.Status, err = buf.ReadByte(); err != nil {
		return
	}

	// And the checksum -- responses *always* have checksums per spec!
	// io.ReadFull is used for every fixed-size field. A plain Read may return
	// fewer bytes without an error, which would leave truncated fields silently
	// zero-padded.
	b = make([]byte, len(hdrCKSUM))
	if _, err = io.ReadFull(buf, b); err != nil {
		return
	}
	if !bytes.Equal(b, hdrCKSUM) {
		err = ErrBadHdr
		return
	}
	b = make([]byte, CksumPackedSize)
	if _, err = io.ReadFull(buf, b); err != nil {
		return
	}
	r.Checksum = UnpackUint32(b)

	// Message start header. It is outside the checksummed body, so validate it
	// the same way as the checksum header.
	b = make([]byte, len(hdrMSGSTART))
	if _, err = io.ReadFull(buf, b); err != nil {
		return
	}
	if !bytes.Equal(b, hdrMSGSTART) {
		err = ErrBadHdr
		return
	}

	// Get the protocol version.
	b = make([]byte, PackedNumSize)
	if _, err = io.ReadFull(buf, b); err != nil {
		return
	}
	r.ProtocolVersion = UnpackUint32(b)

	// Everything from BODYSTART through BODYEND is collected in msgBuf so it
	// can be checksummed.
	if _, err = io.CopyN(msgBuf, buf, int64(len(hdrBODYSTART))); err != nil {
		return
	}
	// Get the count of record groups
	b = make([]byte, PackedNumSize)
	if _, err = io.ReadFull(buf, b); err != nil {
		return
	}
	rgCnt = UnpackInt(b)
	if _, err = msgBuf.Write(b); err != nil {
		return
	}
	// Get the size of record groups
	b = make([]byte, PackedNumSize)
	if _, err = io.ReadFull(buf, b); err != nil {
		return
	}
	bodySize = UnpackInt(b)
	if _, err = msgBuf.Write(b); err != nil {
		return
	}
	// And the record groups themselves, plus the BODYEND footer.
	if _, err = io.CopyN(msgBuf, buf, int64(bodySize+len(hdrBODYEND))); err != nil {
		return
	}

	// Now validate the checksum before continuing.
	if crc32.ChecksumIEEE(msgBuf.Bytes()) != r.Checksum {
		err = ErrBadCksum
		return
	}

	// Each record group consumes at least its own count and size fields
	// (see the loop below), so a larger count cannot be valid.
	// Rejecting bad values here has three effects:
	//   - it avoids an attacker-sized allocation;
	//   - it avoids a Truncate panic on a negative size;
	//   - a negative bodySize cannot slip past io.CopyN, which copies nothing
	//     and reports no error for n < 0.
	if rgCnt < 0 || bodySize < 0 || rgCnt > bodySize/(PackedNumSize*2) {
		err = fmt.Errorf("invalid record group count %d for record group size %d", rgCnt, bodySize)
		return
	}

	// Now that the checksum has validated, trim the msgBuf to only RGs.
	// Skip over the BODYSTART, record group count, and record group size (all
	// known to be present), then truncate off BODYEND.
	msgBuf.Next(len(hdrBODYSTART) + (PackedNumSize * 2))
	msgBuf.Truncate(bodySize)

	r.RecordGroups = make([]*ResponseRecordGroup, rgCnt)
	for idx := 0; idx < rgCnt; idx++ {
		rgBuf = new(bytes.Buffer)
		// The RG unmarshaler handles the record count (and is handed the size
		// field as well), so both go into rgBuf ahead of the records.
		if _, err = io.CopyN(rgBuf, msgBuf, int64(PackedNumSize)); err != nil {
			return
		}
		b = make([]byte, PackedNumSize)
		if _, err = io.ReadFull(msgBuf, b); err != nil {
			return
		}
		if _, err = rgBuf.Write(b); err != nil {
			return
		}
		rgSize = UnpackInt(b)
		if rgSize < 0 {
			// io.CopyN would silently copy nothing for a negative length.
			err = fmt.Errorf("invalid size %d for record group %d", rgSize, idx+1)
			return
		}
		if _, err = io.CopyN(rgBuf, msgBuf, int64(rgSize)); err != nil {
			return
		}
		r.RecordGroups[idx] = new(ResponseRecordGroup)
		if err = r.RecordGroups[idx].UnmarshalBinary(rgBuf.Bytes()); err != nil {
			return
		}
	}

	// Populate the cached size.
	_ = r.Size()
	return
}
// getIdx is a no-op for a Response: it always returns 0. It exists only for
// Model conformance.
func (r *Response) getIdx() (idx int) {
	return 0
}
// getRecordGroups returns the RecordGroups in this Response as a freshly
// allocated []RecordGroup. The slice is never nil, even when there are no
// record groups.
func (r *Response) getRecordGroups() (recordGroups []RecordGroup) {
	recordGroups = make([]RecordGroup, 0, len(r.RecordGroups))
	for _, group := range r.RecordGroups {
		recordGroups = append(recordGroups, group)
	}
	return
}