package server

import (
	`fmt`
	`net`
	`net/http`
	`net/url`
	`os`
	`path/filepath`
	`reflect`
	`sort`
	`strings`

	`github.com/mileusna/useragent`
	`r00t2.io/clientinfo/args`
	`r00t2.io/goutils/logging`
	`r00t2.io/goutils/multierr`
	`r00t2.io/sysutils/paths`
)

/*
	NewClient returns a R00tClient from a UA string.

	uaStr is the raw User-Agent string; if it is empty or only whitespace,
	err will be ErrEmptyUA and r will be nil.

	The parsed useragent.UserAgent is converted into the returned R00tClient
	via reflectClient, and is also retained on the client itself.
*/
func NewClient(uaStr string) (r *R00tClient, err error) {

	var newR R00tClient
	var ua useragent.UserAgent

	if strings.TrimSpace(uaStr) == "" {
		err = ErrEmptyUA
		return
	}

	ua = useragent.Parse(uaStr)

	if err = reflectClient(&ua, &newR); err != nil {
		return
	}
	newR.ua = &ua

	r = &newR

	return
}

/*
	NewServer returns a Server ready to use. Be sure to call Close to free up resources when done.

	log may be nil, in which case a logging.NullLogger is used.
	cliArgs must not be nil (err will be ErrNoArgs if it is).

	The listener is opened here, based on the scheme of cliArgs.Listen.Listen:

	* "unix": the socket's parent directory is created/chmodded/chowned, the
	  socket is opened, and then the socket itself is chmodded/chowned
	  (using the values from cliArgs.ModesAndOwners).
	* "http", "tcp": a TCP listener is opened on the URI's host.

	Any other scheme results in ErrInvalidScheme.

	On success, cliArgs.Listen.Listen is overwritten with the normalized listener URI.
*/
func NewServer(log logging.Logger, cliArgs *args.Args) (srv *Server, err error) {

	var s Server
	var udsSockPerms args.UdsPerms

	if log == nil {
		log = &logging.NullLogger{}
	}

	if cliArgs == nil {
		err = ErrNoArgs
		log.Err("server.NewServer: Received error creating server: %v", err)
		return
	}

	// NOTE(review): these channels are unbuffered; if they are handed to signal.Notify
	// elsewhere, a buffered channel is recommended so signals aren't dropped - confirm.
	s = Server{
		log:        log,
		args:       cliArgs,
		mux:        http.NewServeMux(),
		sock:       nil,
		reloadChan: make(chan os.Signal),
		stopChan:   make(chan os.Signal),
	}

	s.mux.HandleFunc("/", s.handleDefault)
	s.mux.HandleFunc("/about", s.handleAbout)
	s.mux.HandleFunc("/about.html", s.handleAbout)
	s.mux.HandleFunc("/usage", s.handleUsage)
	s.mux.HandleFunc("/usage.html", s.handleUsage)
	s.mux.HandleFunc("/favicon.ico", s.explicit404)

	if s.listenUri, err = url.Parse(cliArgs.Listen.Listen); err != nil {
		s.log.Err("server.NewServer: Failed to parse listener URI: %v", err)
		return
	}
	s.listenUri.Scheme = strings.ToLower(s.listenUri.Scheme)

	switch s.listenUri.Scheme {
	case "unix":
		if udsSockPerms, err = cliArgs.ModesAndOwners(); err != nil {
			s.log.Err("server.NewServer: Failed to parse unix socket permissions: %v", err)
			return
		}
		if err = paths.RealPath(&s.listenUri.Path); err != nil {
			s.log.Err("server.NewServer: Failed to canonize/resolve socket path '%s': %v", s.listenUri.Path, err)
			return
		}
		// Cleanup any stale socket.
		if err = s.cleanup(true); err != nil {
			s.log.Err("server.NewServer: Failed to cleanup for 'unix' listener: %v", err)
			return
		}
		if err = os.MkdirAll(filepath.Dir(s.listenUri.Path), udsSockPerms.DMode); err != nil {
			s.log.Err("server.NewServer: Received error creating socket directory '%s': %v", filepath.Dir(s.listenUri.Path), err)
			return
		}
		// MkdirAll is a no-op for a pre-existing directory, so the mode/ownership are enforced explicitly.
		if err = os.Chmod(filepath.Dir(s.listenUri.Path), udsSockPerms.DMode); err != nil {
			s.log.Err("server.NewServer: Received error chmodding socket directory '%s': %v", filepath.Dir(s.listenUri.Path), err)
			return
		}
		if err = os.Chown(filepath.Dir(s.listenUri.Path), udsSockPerms.UID, udsSockPerms.DGID); err != nil {
			s.log.Err("server.NewServer: Received error chowning socket directory '%s': %v", filepath.Dir(s.listenUri.Path), err)
			return
		}
		// Rebuild the URI from only the scheme and the resolved path.
		if s.listenUri, err = url.Parse(
			fmt.Sprintf(
				"%s://%s",
				s.listenUri.Scheme, s.listenUri.Path,
			),
		); err != nil {
			s.log.Err("server.NewServer: Failed to re-parse listener URI: %v", err)
			return
		}
		if s.sock, err = net.Listen("unix", s.listenUri.Path); err != nil {
			s.log.Err("server.NewServer: Failed to open socket on '%s': %v", s.listenUri.Path, err)
			// Bail out here; there is no socket to chmod/chown, and continuing would mask this error.
			return
		}
		if err = os.Chmod(s.listenUri.Path, udsSockPerms.FMode); err != nil {
			s.log.Err("server.NewServer: Received error chmodding socket '%s': %v", s.listenUri.Path, err)
			return
		}
		if err = os.Chown(s.listenUri.Path, udsSockPerms.UID, udsSockPerms.FGID); err != nil {
			s.log.Err("server.NewServer: Received error chowning socket '%s': %v", s.listenUri.Path, err)
			return
		}
	case "http", "tcp":
		s.isHttp = s.listenUri.Scheme == "http"
		if err = s.cleanup(true); err != nil {
			s.log.Err("server.NewServer: Failed to cleanup for '%s' listener: %v", strings.ToUpper(s.listenUri.Scheme), err)
			return
		}
		// Rebuild the URI from only the scheme, host, and path.
		if s.listenUri, err = url.Parse(
			fmt.Sprintf(
				"%s://%s%s",
				s.listenUri.Scheme, s.listenUri.Host, s.listenUri.Path,
			),
		); err != nil {
			s.log.Err("server.NewServer: Failed to re-parse listener URI: %v", err)
			return
		}
		if s.sock, err = net.Listen("tcp", s.listenUri.Host); err != nil {
			s.log.Err("server.NewServer: Failed to open %s socket on '%s': %v", strings.ToUpper(s.listenUri.Scheme), s.listenUri.Host, err)
			return
		}
	default:
		s.log.Err("server.NewServer: Unsupported scheme: %v", s.listenUri.Scheme)
		err = ErrInvalidScheme
		return
	}

	cliArgs.Listen.Listen = s.listenUri.String()

	srv = &s

	return
}

/*
	decideParseAccept takes the slice returned from parseAccept
	and chooses a format based on what MIME types are supported by this program.

	parsed is walked in order, and the first entry that matches wins:

	* "*\/*" (without the backslash): defFormat
	* "application/*": mediaJSON
	* "text/*": mediaHTML
	* mediaHTML, mediaJSON, mediaXML, mediaYAML: that MIME type itself

	err will be an ErrUnsupportedMIME if no supported MIME type is found
	(format will still be set to defFormat in that case).

	If parsed is nil or empty, format will be defFormat and err will be nil.
*/
func decideParseAccept(parsed []*parsedMIME, defFormat string) (format string, err error) {

	// len() of a nil slice is 0, so this covers both nil and empty.
	if len(parsed) == 0 {
		format = defFormat
		return
	}

	// NOTE(review): entries with a Weight of 0 ("not acceptable" in RFC 9110 terms)
	// are not filtered out here - confirm whether that is intended.
	for _, pf := range parsed {
		switch pf.MIME {
		case "*/*":
			// Client explicitly accepts anything
			format = defFormat
			return
		case "application/*":
			// Use JSON
			format = mediaJSON
			return
		case "text/*":
			// Use HTML
			format = mediaHTML
			return
		case mediaHTML, mediaJSON, mediaXML, mediaYAML:
			format = pf.MIME
			return
		}
	}

	// Nothing the client listed is supported.
	format = defFormat
	err = ErrUnsupportedMIME

	return
}

/*
	reflectClient takes a src and dst and attempts to set/convert src to dst.

	It is *VERY STRICT*.

	It is expected that src does NOT use pointers.

	...This is pretty much just custom-made for converting a useragent.UserAgent to a R00tClient.
	Don't use it for anything else.
*/
func reflectClient(src, dst any) (err error) {

	var srcVal reflect.Value = reflect.ValueOf(src)
	var dstVal reflect.Value = reflect.ValueOf(dst)

	// Both must be ptrs to a struct.
	// (A nil src/dst has Kind reflect.Invalid, so it is rejected here as well.)
	if srcVal.Kind() != reflect.Ptr || dstVal.Kind() != reflect.Ptr {
		err = ErrPtrNeeded
		return
	}

	// The struct check (ErrStructNeeded) happens on the pointed-to values.
	err = reflectClientStruct(srcVal.Elem(), dstVal.Elem())

	return
}

/*
	reflectClientStruct copies matching fields from the struct srcVal into the struct dstVal.

	Both must have Kind reflect.Struct, otherwise ErrStructNeeded is returned.

	For each settable field of dstVal, the source field name is taken from the convertTag
	struct tag (falling back to the dst field's own name if the tag is absent).
	A field is skipped if it is tagged "-", if src has no field of that name,
	or if the src field is its type's zero value.

	dst fields that are pointers to structs are allocated if nil and recursed into;
	everything else is handled by reflectClientScalar.
*/
func reflectClientStruct(srcVal, dstVal reflect.Value) (err error) {

	var dstType reflect.Type
	var dstField reflect.StructField
	var dstFieldVal reflect.Value
	var srcFieldVal reflect.Value
	var srcField string
	var tagged bool

	if srcVal.Kind() != reflect.Struct || dstVal.Kind() != reflect.Struct {
		err = ErrStructNeeded
		return
	}

	dstType = dstVal.Type()

	for i := 0; i < dstVal.NumField(); i++ {
		dstField = dstType.Field(i)
		dstFieldVal = dstVal.Field(i)

		// Skip unexported
		if !dstFieldVal.CanSet() {
			continue
		}

		if srcField, tagged = dstField.Tag.Lookup(convertTag); !tagged {
			// If no explicit field name is present, set it to the dst field name.
			srcField = dstField.Name
		} else if srcField == "-" {
			// Skip explicitly skipped (:"-")
			continue
		}

		// Get the value from src
		srcFieldVal = srcVal.FieldByName(srcField)

		// Skip invalid and zero-value.
		// IsZero is used (rather than comparing Interface() values) so an unexported src field can't cause a panic.
		if !srcFieldVal.IsValid() || srcFieldVal.IsZero() {
			continue
		}

		// Structs need to recurse.
		if dstFieldVal.Kind() == reflect.Ptr && dstFieldVal.Type().Elem().Kind() == reflect.Struct {
			// Ensure we don't have a nil ptr
			if dstFieldVal.IsNil() {
				dstFieldVal.Set(reflect.New(dstFieldVal.Type().Elem()))
			}
			// And recurse into it.
			if err = reflectClientStruct(srcFieldVal, dstFieldVal.Elem()); err != nil {
				return
			}
			continue
		}

		// Everything else gets assigned here.
		if err = reflectClientScalar(srcFieldVal, dstFieldVal); err != nil {
			return
		}
	}

	return
}

/*
	reflectClientScalar assigns the bool, string, or int in srcFieldVal to dstFieldVal.

	dstFieldVal may be a bool, string, or int, or a pointer to one of those
	(in which case a new value is allocated for it to point to).
	Any other dst Kind results in ErrUnhandledField.

	The Kind of srcFieldVal must match exactly, otherwise ErrIncompatFieldType is returned.

	Values are written with SetBool/SetString/SetInt so that defined types
	(e.g. "type Flag bool") are assigned properly instead of panicking.
*/
func reflectClientScalar(srcFieldVal, dstFieldVal reflect.Value) (err error) {

	var alloc reflect.Value
	var target reflect.Value = dstFieldVal
	var kind reflect.Kind = dstFieldVal.Kind()
	var isPtr bool = kind == reflect.Ptr

	if isPtr {
		// Pointers to the supported kinds
		kind = dstFieldVal.Type().Elem().Kind()
	}

	switch kind {
	case reflect.Bool, reflect.String, reflect.Int:
		// Supported.
	default:
		err = ErrUnhandledField
		return
	}

	if srcFieldVal.Kind() != kind {
		err = ErrIncompatFieldType
		return
	}

	if isPtr {
		// Allocate using dst's actual element type, then write through the pointer.
		alloc = reflect.New(dstFieldVal.Type().Elem())
		target = alloc.Elem()
	}

	switch kind {
	case reflect.Bool:
		target.SetBool(srcFieldVal.Bool())
	case reflect.String:
		target.SetString(srcFieldVal.String())
	case reflect.Int:
		target.SetInt(srcFieldVal.Int())
	}

	if isPtr {
		dstFieldVal.Set(alloc)
	}

	return
}

// parseAccept parses an Accept header as per RFC 9110 § 12.5.1.
func parseAccept(hdrVal string) (parsed []*parsedMIME, err error) { var mimes []string var parts []string var params []string var paramsLen int var kv []string var mt *parsedMIME var mErr *multierr.MultiError = multierr.NewMultiError(nil) if hdrVal == "" { return } mimes = strings.Split(hdrVal, ",") for _, mime := range mimes { mt = &parsedMIME{ MIME: "", Weight: 1.0, // between 0.0 and 1.0 Params: nil, } mime = strings.TrimSpace(mime) // Split into []string{[, , ...]} parts = strings.Split(mime, ";") if parts == nil || len(parts) < 1 { mErr.AddError(ErrInvalidAccept) continue } if parts[0] == "" { mErr.AddError(ErrInvalidAccept) continue } if len(strings.Split(parts[0], "/")) != 2 { mErr.AddError(ErrInvalidAccept) continue } mt.MIME = strings.TrimSpace(parts[0]) if len(parts) > 1 { // Parameters were provided. We don't really use them except `q`, but... params = parts[1:] paramsLen = len(params) for idx, param := range params { param = strings.TrimSpace(param) kv = strings.SplitN(param, "=", 2) if len(kv) != 2 { mErr.AddError(ErrInvalidAccept) continue } if kv[0] == "q" && idx == paramsLen-1 { // It's the weight. RFC's pretty clear it's the last param. fmt.Sscanf(kv[1], "%f", &mt.Weight) if mt.Weight > 1.0 || mt.Weight < 0.0 { mErr.AddError(ErrInvalidAccept) continue } } else { if mt.Params == nil { mt.Params = make(map[string]string) } mt.Params[kv[0]] = kv[1] } } } parsed = append(parsed, mt) } // Now sort by weight (descending). sort.SliceStable( parsed, func(i, j int) (isBefore bool) { isBefore = parsed[i].Weight > parsed[j].Weight return }, ) if !mErr.IsEmpty() { err = mErr return } return }