go_clientinfo/server/funcs.go
2024-12-12 02:22:54 -05:00

431 lines
11 KiB
Go

package server
import (
`fmt`
`net`
`net/http`
`net/url`
`os`
`path/filepath`
`reflect`
`sort`
`strings`
`github.com/mileusna/useragent`
`r00t2.io/clientinfo/args`
`r00t2.io/goutils/logging`
`r00t2.io/goutils/multierr`
`r00t2.io/sysutils/paths`
)
/*
NewClient parses the User-Agent string uaStr and returns the R00tClient built from it.

If uaStr is empty (or only whitespace), err will be ErrEmptyUA and r will be nil.
Any error from converting the parsed User-Agent is returned as-is, with r left nil.
*/
func NewClient(uaStr string) (r *R00tClient, err error) {

	// Guard: nothing to parse.
	if strings.TrimSpace(uaStr) == "" {
		err = ErrEmptyUA
		return
	}

	parsedUA := useragent.Parse(uaStr)
	client := new(R00tClient)

	// Copy the parsed fields over to the client representation.
	if err = reflectClient(&parsedUA, client); err != nil {
		return
	}

	// Keep the parsed User-Agent alongside the converted fields.
	client.ua = &parsedUA
	r = client

	return
}
/*
NewServer returns a Server ready to use. Be sure to call Close to free up resources when done.

If log is nil, a logging.NullLogger is used instead.
If cliArgs is nil, err will be ErrNoArgs.

The listener is selected by the scheme of cliArgs.Listen.Listen (compared case-insensitively):

  - "unix": the socket path is resolved, cleanup is run for any stale socket, the parent directory
    is created/chmodded/chowned per cliArgs.ModesAndOwners, and a UNIX socket listener is opened and
    then chmodded/chowned.
  - "http" or "tcp": a TCP listener is opened on the URI's host (isHttp is only set for "http").

Any other scheme results in ErrInvalidScheme.

On success, cliArgs.Listen.Listen is overwritten with the normalized listener URI.
On any failure, srv is nil and err is set (and has been logged).
*/
func NewServer(log logging.Logger, cliArgs *args.Args) (srv *Server, err error) {

	var s Server
	var udsSockPerms args.UdsPerms
	var sockDir string

	if log == nil {
		log = &logging.NullLogger{}
	}

	if cliArgs == nil {
		err = ErrNoArgs
		log.Err("server.NewServer: Received error creating server: %v", err)
		return
	}

	s = Server{
		log:        log,
		args:       cliArgs,
		mux:        http.NewServeMux(),
		sock:       nil,
		reloadChan: make(chan os.Signal),
		stopChan:   make(chan os.Signal),
	}

	s.mux.HandleFunc("/", s.handleDefault)
	s.mux.HandleFunc("/about", s.handleAbout)
	s.mux.HandleFunc("/about.html", s.handleAbout)
	s.mux.HandleFunc("/usage", s.handleUsage)
	s.mux.HandleFunc("/usage.html", s.handleUsage)
	s.mux.HandleFunc("/favicon.ico", s.explicit404)

	if s.listenUri, err = url.Parse(cliArgs.Listen.Listen); err != nil {
		s.log.Err("server.NewServer: Failed to parse listener URI: %v", err)
		return
	}
	s.listenUri.Scheme = strings.ToLower(s.listenUri.Scheme)

	switch s.listenUri.Scheme {
	case "unix":
		if udsSockPerms, err = cliArgs.ModesAndOwners(); err != nil {
			s.log.Err("server.NewServer: Failed to parse unix socket permissions: %v", err)
			return
		}
		if err = paths.RealPath(&s.listenUri.Path); err != nil {
			s.log.Err("server.NewServer: Failed to canonize/resolve socket path '%s': %v", s.listenUri.Path, err)
			return
		}
		// Cleanup any stale socket.
		if err = s.cleanup(true); err != nil {
			s.log.Err("server.NewServer: Failed to cleanup for 'unix' listener: %v", err)
			return
		}
		// Prepare the directory that will hold the socket.
		sockDir = filepath.Dir(s.listenUri.Path)
		if err = os.MkdirAll(sockDir, udsSockPerms.DMode); err != nil {
			s.log.Err("server.NewServer: Received error creating socket directory '%s': %v", sockDir, err)
			return
		}
		if err = os.Chmod(sockDir, udsSockPerms.DMode); err != nil {
			s.log.Err("server.NewServer: Received error chmodding socket directory '%s': %v", sockDir, err)
			return
		}
		if err = os.Chown(sockDir, udsSockPerms.UID, udsSockPerms.DGID); err != nil {
			s.log.Err("server.NewServer: Received error chowning socket directory '%s': %v", sockDir, err)
			return
		}
		// Rebuild the URI from only the scheme and the resolved path.
		if s.listenUri, err = url.Parse(
			fmt.Sprintf(
				"%s://%s",
				s.listenUri.Scheme, s.listenUri.Path,
			),
		); err != nil {
			s.log.Err("server.NewServer: Failed to re-parse listener URI: %v", err)
			return
		}
		if s.sock, err = net.Listen("unix", s.listenUri.Path); err != nil {
			s.log.Err("server.NewServer: Failed to open socket on '%s': %v", s.listenUri.Path, err)
			// Without this return, the Chmod below would overwrite (and possibly mask) the listen error.
			return
		}
		if err = os.Chmod(s.listenUri.Path, udsSockPerms.FMode); err != nil {
			s.log.Err("server.NewServer: Received error chmodding socket '%s': %v", s.listenUri.Path, err)
			// The caller never receives the Server on failure, so release the listener here.
			_ = s.sock.Close()
			return
		}
		if err = os.Chown(s.listenUri.Path, udsSockPerms.UID, udsSockPerms.FGID); err != nil {
			s.log.Err("server.NewServer: Received error chowning socket '%s': %v", s.listenUri.Path, err)
			// The caller never receives the Server on failure, so release the listener here.
			_ = s.sock.Close()
			return
		}
	case "http", "tcp":
		s.isHttp = s.listenUri.Scheme == "http"
		if err = s.cleanup(true); err != nil {
			s.log.Err("server.NewServer: Failed to cleanup for '%s' listener: %v", strings.ToUpper(s.listenUri.Scheme), err)
			return
		}
		// Rebuild the URI from only the scheme, host, and path.
		if s.listenUri, err = url.Parse(
			fmt.Sprintf(
				"%s://%s%s",
				s.listenUri.Scheme, s.listenUri.Host, s.listenUri.Path,
			),
		); err != nil {
			s.log.Err("server.NewServer: Failed to re-parse listener URI: %v", err)
			return
		}
		if s.sock, err = net.Listen("tcp", s.listenUri.Host); err != nil {
			s.log.Err("server.NewServer: Failed to open %s socket on '%s': %v", strings.ToUpper(s.listenUri.Scheme), s.listenUri.Host, err)
			return
		}
	default:
		s.log.Err("server.NewServer: Unsupported scheme: %v", s.listenUri.Scheme)
		err = ErrInvalidScheme
		return
	}

	// Reflect the normalized URI back into the args.
	cliArgs.Listen.Listen = s.listenUri.String()

	srv = &s

	return
}
/*
decideParseAccept takes the slice returned from parseAccept, runs parseAccept on it,
and chooses based on what MIME types are supported by this program.
err will be an ErrUnsupportedMIME if no supported MIME type is found.
If parsed is nil or empty, format will be defFormat and err will be nil.
*/
func decideParseAccept(parsed []*parsedMIME, defFormat string) (format string, err error) {
var customFmtFound bool
if parsed == nil || len(parsed) == 0 {
format = defFormat
return
}
for _, pf := range parsed {
switch pf.MIME {
case "*/*": // Client explicitly accept anything
format = defFormat
customFmtFound = true
case "application/*": // Use JSON
format = mediaJSON
customFmtFound = true
case "text/*": // Use HTML
format = mediaHTML
customFmtFound = true
case mediaHTML, mediaJSON, mediaXML, mediaYAML:
format = pf.MIME
customFmtFound = true
}
if customFmtFound {
break
}
}
if !customFmtFound {
format = defFormat
err = ErrUnsupportedMIME
return
}
return
}
/*
reflectClient takes a src and dst and attempts to set/convert src to dst. It is *VERY STRICT*.

Both src and dst must be pointers to structs (ErrPtrNeeded/ErrStructNeeded otherwise).
It is expected that the fields of src do NOT use pointers.

For each settable field of dst:

  - The src field name is taken from the field's convertTag struct tag, falling back to the dst field's own name.
    A tag value of "-" skips the field.
  - Fields that do not exist in src, or hold their zero value in src, are skipped (dst is left untouched).
  - A dst field that is a pointer to a struct is allocated if nil and then recursed into.
  - A dst field of kind bool, string, or int (or a pointer to one of those) is set from a src field of the same kind;
    a differing src kind results in ErrIncompatFieldType.
  - Any other kind of dst field results in ErrUnhandledField.

Values are transferred by kind rather than by exact type, so named types (e.g. `type Foo string`) are converted
instead of causing a panic.

...This is pretty much just custom-made for converting a useragent.UserAgent to a R00tClient.
Don't use it for anything else.
*/
func reflectClient(src, dst any) (err error) {

	var dstField reflect.StructField
	var dstFieldVal reflect.Value
	var srcFieldVal reflect.Value
	var srcField string
	var ok bool
	var isPtr bool
	var targetKind reflect.Kind
	var target reflect.Value
	var srcVal reflect.Value = reflect.ValueOf(src)
	var dstVal reflect.Value = reflect.ValueOf(dst)

	// Both must be ptrs to a struct
	if srcVal.Kind() != reflect.Ptr || dstVal.Kind() != reflect.Ptr {
		err = ErrPtrNeeded
		return
	}

	srcVal = srcVal.Elem()
	dstVal = dstVal.Elem()

	/*
		Now that we have the underlying type/value of the ptr above,
		check for structs.
	*/
	if srcVal.Kind() != reflect.Struct || dstVal.Kind() != reflect.Struct {
		err = ErrStructNeeded
		return
	}

	for i := 0; i < dstVal.NumField(); i++ {
		dstField = dstVal.Type().Field(i)
		dstFieldVal = dstVal.Field(i)

		// Skip unexported
		if !dstFieldVal.CanSet() {
			continue
		}

		if srcField, ok = dstField.Tag.Lookup(convertTag); !ok {
			// If no explicit field name is present, set it to the dst field name.
			srcField = dstField.Name
		} else if srcField == "-" {
			// Skip explicitly skipped (<convertTag>:"-")
			continue
		}

		// Get the value from src, skipping invalid (nonexistent) and zero-value fields.
		srcFieldVal = srcVal.FieldByName(srcField)
		if !srcFieldVal.IsValid() || srcFieldVal.IsZero() {
			continue
		}

		// For pointer fields, what matters is the kind of what's being pointed to.
		targetKind = dstFieldVal.Kind()
		isPtr = targetKind == reflect.Ptr
		if isPtr {
			targetKind = dstFieldVal.Type().Elem().Kind()
		}

		// Structs need to recurse.
		if isPtr && targetKind == reflect.Struct {
			// Ensure we don't have a nil ptr
			if dstFieldVal.IsNil() {
				dstFieldVal.Set(reflect.New(dstFieldVal.Type().Elem()))
			}
			// And recurse into it.
			if err = reflectClient(srcFieldVal.Addr().Interface(), dstFieldVal.Interface()); err != nil {
				return
			}
			continue
		}

		// Everything else gets assigned here; only these kinds are supported.
		switch targetKind {
		case reflect.Bool, reflect.String, reflect.Int:
		default:
			err = ErrUnhandledField
			return
		}
		if srcFieldVal.Kind() != targetKind {
			err = ErrIncompatFieldType
			return
		}

		// Pointer fields get a freshly-allocated value of their own element type to populate.
		target = dstFieldVal
		if isPtr {
			target = reflect.New(dstFieldVal.Type().Elem()).Elem()
		}

		// The kind-specific getters/setters work regardless of whether the types are named.
		switch targetKind {
		case reflect.Bool:
			target.SetBool(srcFieldVal.Bool())
		case reflect.String:
			target.SetString(srcFieldVal.String())
		case reflect.Int:
			target.SetInt(srcFieldVal.Int())
		}

		if isPtr {
			dstFieldVal.Set(target.Addr())
		}
	}

	return
}
/*
parseAccept parses an Accept header as per RFC 9110 § 12.5.1.

hdrVal is the raw header value; an empty hdrVal returns a nil parsed and a nil err.

Each comma-separated media range becomes a parsedMIME with a default Weight of 1.0.
A `q` parameter (matched case-insensitively) in the last parameter position sets the Weight;
all other parameters are stored in Params. Empty list elements (e.g. from a trailing comma)
are ignored, as per RFC 9110 § 5.6.1.2.

parsed is sorted by Weight, descending; entries of equal Weight retain the order they had in hdrVal.

Malformed media ranges are skipped and malformed parameters are not applied; each adds an ErrInvalidAccept.
If any were encountered, err will be a *multierr.MultiError containing them, and parsed still contains
whatever could be parsed.
*/
func parseAccept(hdrVal string) (parsed []*parsedMIME, err error) {

	var mimeType string
	var parts []string
	var typeParts []string
	var params []string
	var paramsLen int
	var kv []string
	var scanErr error
	var mt *parsedMIME
	var mErr *multierr.MultiError = multierr.NewMultiError(nil)

	if hdrVal == "" {
		return
	}

	for _, mime := range strings.Split(hdrVal, ",") {
		mime = strings.TrimSpace(mime)
		if mime == "" {
			// Empty list elements are legal and must be ignored.
			continue
		}

		// Split into []string{<type>[, <param>, ...]}
		parts = strings.Split(mime, ";")

		// The media range must be exactly <type>/<subtype>, with neither side empty.
		mimeType = strings.TrimSpace(parts[0])
		typeParts = strings.Split(mimeType, "/")
		if len(typeParts) != 2 || typeParts[0] == "" || typeParts[1] == "" {
			mErr.AddError(ErrInvalidAccept)
			continue
		}

		mt = &parsedMIME{
			MIME:   mimeType,
			Weight: 1.0, // between 0.0 and 1.0
			Params: nil,
		}

		// Parameters, if provided. We don't really use them except `q`, but...
		params = parts[1:]
		paramsLen = len(params)
		for idx, param := range params {
			kv = strings.SplitN(strings.TrimSpace(param), "=", 2)
			if len(kv) != 2 {
				mErr.AddError(ErrInvalidAccept)
				continue
			}

			if strings.EqualFold(kv[0], "q") && idx == paramsLen-1 {
				// It's the weight. RFC's pretty clear it's the last param.
				if _, scanErr = fmt.Sscanf(kv[1], "%f", &mt.Weight); scanErr != nil {
					// An unparseable weight is an error, not silently the default.
					mErr.AddError(ErrInvalidAccept)
					continue
				}
				if mt.Weight > 1.0 || mt.Weight < 0.0 {
					mErr.AddError(ErrInvalidAccept)
				}
				continue
			}

			if mt.Params == nil {
				mt.Params = make(map[string]string)
			}
			mt.Params[kv[0]] = kv[1]
		}

		parsed = append(parsed, mt)
	}

	// Now sort by weight (descending).
	sort.SliceStable(
		parsed,
		func(i, j int) (isBefore bool) {
			isBefore = parsed[i].Weight > parsed[j].Weight
			return
		},
	)

	// Only assign err if there's something to report, so a successful parse returns a true nil error.
	if !mErr.IsEmpty() {
		err = mErr
	}

	return
}