package wireproto import ( `bytes` `cmp` `fmt` `hash/crc32` `io` `strings` `github.com/google/uuid` ) // ConnId returns a copy of the connection ID. func (r *Request) ConnId() (conn uuid.UUID) { conn = r.connId return } // MarshalBinary renders a Request into a byte-packed format. func (r *Request) MarshalBinary() (data []byte, err error) { var b []byte var msgSize int var cksum uint32 var buf *bytes.Buffer = new(bytes.Buffer) var msgBuf *bytes.Buffer = new(bytes.Buffer) _ = r.Size() for _, i := range r.RecordGroups { msgSize += i.Size() } // The message "body" - this is done first so the body can be checksummed (if not nil). if _, err = msgBuf.Write(hdrBODYSTART); err != nil { return } // Record group count if _, err = msgBuf.Write(PackInt(len(r.RecordGroups))); err != nil { return } // And size. if _, err = msgBuf.Write(PackInt(msgSize)); err != nil { return } for _, i := range r.RecordGroups { if b, err = i.MarshalBinary(); err != nil { return } if _, err = msgBuf.Write(b); err != nil { return } } if _, err = msgBuf.Write(hdrBODYEND); err != nil { return } // Now the surrounding request. // Checksum - update and serialize if not null. if r.Checksum != nil { cksum = crc32.ChecksumIEEE(msgBuf.Bytes()) *r.Checksum = cksum if _, err = buf.Write(hdrCKSUM); err != nil { return } if _, err = buf.Write(cksumBytes(*r.Checksum)); err != nil { return } } // Message start if _, err = buf.Write(hdrMSGSTART); err != nil { return } // Protocol version if _, err = buf.Write(PackUint32(r.ProtocolVersion)); err != nil { return } // Then copy the msgBuf in. if _, err = msgBuf.WriteTo(buf); err != nil { return } // And then the message end. if _, err = buf.Write(hdrMSGEND); err != nil { return } data = buf.Bytes() return } // Model returns an indented string representation of the model. func (r *Request) Model() (out string) { out = r.ModelCustom(IndentChars, SeparatorChars, indentR) return } func (r *Request) ModelCustom(indent, sep string, level uint) (out string) { var maxFtr int var size int var sb strings.Builder for _, rg := range r.RecordGroups { size += rg.Size() } // Checksum (optional for Request) sb.WriteString(strings.Repeat(indent, int(level))) if r.Checksum == nil { sb.WriteString(strings.Repeat("-", 8)) sb.WriteString(sep) sb.WriteString("// (No Checksum Present)\n") } else { // HDR: CKSUM sb.WriteString(padBytesRight(hdrCKSUM, 8)) sb.WriteString(sep) sb.WriteString("// HDR:CKSUM\n") // Checksum sb.WriteString(strings.Repeat(indent, int(level))) sb.WriteString(padBytesRight(cksumBytes(*r.Checksum), 8)) sb.WriteString(sep) sb.WriteString(fmt.Sprintf("// Checksum Value (%d)\n", *r.Checksum)) } // Header: MSGSTART sb.WriteString(strings.Repeat(indent, int(level))) sb.WriteString(padBytesRight(hdrMSGSTART, 8)) sb.WriteString(sep) sb.WriteString("// HDR:MSGSTART\n") // Protocol Version sb.WriteString(strings.Repeat(indent, int(level))) sb.WriteString(padIntRight(int(r.ProtocolVersion), 8)) sb.WriteString(sep) sb.WriteString(fmt.Sprintf("// Protocol Version (%d)\n", r.ProtocolVersion)) // Header: BODYSTART sb.WriteString(strings.Repeat(indent, int(level))) sb.WriteString(padBytesRight(hdrBODYSTART, 8)) sb.WriteString(sep) sb.WriteString("// HDR:BODYSTART\n") // Count sb.WriteString(strings.Repeat(indent, int(level))) sb.WriteString(padIntRight(len(r.RecordGroups), 8)) sb.WriteString(sep) sb.WriteString(fmt.Sprintf("// Record Group Count (%d)\n", len(r.RecordGroups))) // Size sb.WriteString(strings.Repeat(indent, int(level))) sb.WriteString(padIntRight(size, 8)) sb.WriteString(sep) sb.WriteString(fmt.Sprintf("// Record Groups Size (%d)\n", size)) // VALUES for idx, rg := range r.RecordGroups { sb.WriteString(strings.Repeat(indent, int(level))) sb.WriteString(fmt.Sprintf("// Record Group %d (%d bytes)\n", idx+1, rg.Size())) sb.WriteString(rg.ModelCustom(indent, sep, level+1)) } // Make the footers a little more nicely aligned. switch cmp.Compare(len(hdrBODYEND), len(hdrMSGEND)) { case -1: maxFtr = len(hdrMSGEND) case 1, 0: maxFtr = len(hdrBODYEND) } // Footer: BODYEND sb.WriteString(strings.Repeat(indent, int(level))) sb.WriteString(padBytesRight(hdrBODYEND, maxFtr)) sb.WriteString(sep) sb.WriteString("// HDR:BODYEND\n") // Footer: MSGEND sb.WriteString(strings.Repeat(indent, int(level))) sb.WriteString(padBytesRight(hdrMSGEND, maxFtr)) sb.WriteString(sep) sb.WriteString("// HDR:MSGEND\n") out = sb.String() return } // Resolve associates children with parents. func (r *Request) Resolve() { for idx, i := range r.RecordGroups { i.parent = r i.rgIdx = idx i.Resolve() } } // SetConnID allows for setting a connection ID. This is largely just used for debugging purposes. func (r *Request) SetConnID(connId uuid.UUID) { r.connId = connId for _, rg := range r.RecordGroups { rg.connId = connId for _, rec := range rg.Records { rec.connId = connId for _, kvp := range rec.Pairs { kvp.connId = &connId } } } } // Size returns the Request's calculated size (in bytes) and updates the size field if 0. func (r *Request) Size() (size int) { if r == nil { return } // Checksum if r.Checksum != nil { size += len(hdrCKSUM) size += CksumPackedSize } // Message header size += len(hdrMSGSTART) // Protocol version size += PackedNumSize // Count and Size uint32s size += PackedNumSize * 2 // Message begin size += len(hdrBODYSTART) for _, p := range r.RecordGroups { size += p.Size() } // Message end size += len(hdrBODYEND) // And closing sequence. size += len(hdrMSGEND) if r.common == nil { r.common = new(common) } r.size = uint32(size) return } // ToMap returns a slice of slice of slice of FVP maps for this Message. func (r *Request) ToMap() (m [][][]map[string]interface{}) { m = make([][][]map[string]interface{}, len(r.RecordGroups)) for idx, rg := range r.RecordGroups { m[idx] = rg.ToMap() } return } // UnmarshalBinary populates a Request from packed bytes. func (r *Request) UnmarshalBinary(data []byte) (err error) { if data == nil || len(data) == 0 { return } var b []byte var rgCnt, bodySize int var rgSize int var rgBuf *bytes.Buffer var buf *bytes.Reader = bytes.NewReader(data) var msgBuf *bytes.Buffer = new(bytes.Buffer) if r == nil { *r = Request{} } r.common = &common{} r.size = 0 // Check for a checksum. b = make([]byte, len(hdrMSGSTART)) if _, err = buf.Read(b); err != nil { return } if bytes.Equal(b, hdrCKSUM) { // A checksum header was found. b = make([]byte, CksumPackedSize) if _, err = buf.Read(b); err != nil { return } r.Checksum = new(uint32) *r.Checksum = byteOrder.Uint32(b) // Since only the checksum has been read, now read the MSGSTART... b = make([]byte, len(hdrMSGSTART)) if _, err = buf.Read(b); err != nil { return } // But don't need to do anything with it. } else { // MSGSTART has already been read as part of the checksum check. r.Checksum = nil } // Get the protocol version. b = make([]byte, PackedNumSize) if _, err = buf.Read(b); err != nil { return } r.ProtocolVersion = UnpackUint32(b) // Skip over the BODYSTART (but write it to msgBuf). if _, err = io.CopyN(msgBuf, buf, int64(len(hdrBODYSTART))); err != nil { return } // Get the count of record groups b = make([]byte, PackedNumSize) if _, err = buf.Read(b); err != nil { return } rgCnt = UnpackInt(b) if _, err = msgBuf.Write(b); err != nil { return } // And their size b = make([]byte, PackedNumSize) if _, err = buf.Read(b); err != nil { return } bodySize = UnpackInt(b) if _, err = msgBuf.Write(b); err != nil { return } // And the record groups themselves. if _, err = io.CopyN(msgBuf, buf, int64(bodySize+len(hdrBODYEND))); err != nil { return } // Validate the checksum. if r.Checksum != nil { if crc32.ChecksumIEEE(msgBuf.Bytes()) != *r.Checksum { err = ErrBadCksum return } } // Now that the checksum (if provided) has validated, trim the msgBuf to only RGs. // Skip over the BODYSTART, record group count, and record group size. if _, err = msgBuf.Read(make([]byte, len(hdrBODYSTART)+(PackedNumSize*2))); err != nil { return } // Then truncate. msgBuf.Truncate(bodySize) r.RecordGroups = make([]*RequestRecordGroup, rgCnt) for idx := 0; idx < rgCnt; idx++ { rgBuf = new(bytes.Buffer) // The RG unmarshaler handles the record count, but it needs to be put in msgBuf to do that. if _, err = io.CopyN(rgBuf, msgBuf, int64(PackedNumSize)); err != nil { return } b = make([]byte, PackedNumSize) if _, err = msgBuf.Read(b); err != nil { return } if _, err = rgBuf.Write(b); err != nil { return } rgSize = UnpackInt(b) if _, err = io.CopyN(rgBuf, msgBuf, int64(rgSize)); err != nil { return } r.RecordGroups[idx] = new(RequestRecordGroup) if err = r.RecordGroups[idx].UnmarshalBinary(rgBuf.Bytes()); err != nil { return } } _ = r.Size() return } // getIdx is a NOOP for Messages, but is used for Model conformance. func (r *Request) getIdx() (idx int) { return } // getRecordGroups returns the RecordGroups in this Message. func (r *Request) getRecordGroups() (recordGroups []RecordGroup) { recordGroups = make([]RecordGroup, len(r.RecordGroups)) for idx, rg := range r.RecordGroups { recordGroups[idx] = rg } return }