431 lines
11 KiB
Go
431 lines
11 KiB
Go
package server
|
|
|
|
import (
|
|
`fmt`
|
|
`net`
|
|
`net/http`
|
|
`net/url`
|
|
`os`
|
|
`path/filepath`
|
|
`reflect`
|
|
`sort`
|
|
`strings`
|
|
|
|
`github.com/mileusna/useragent`
|
|
`r00t2.io/clientinfo/args`
|
|
`r00t2.io/goutils/logging`
|
|
`r00t2.io/goutils/multierr`
|
|
`r00t2.io/sysutils/paths`
|
|
)
|
|
|
|
// NewClient returns a R00tClient from a UA string.
|
|
func NewClient(uaStr string) (r *R00tClient, err error) {
|
|
|
|
var newR R00tClient
|
|
var ua useragent.UserAgent
|
|
|
|
if strings.TrimSpace(uaStr) == "" {
|
|
err = ErrEmptyUA
|
|
return
|
|
}
|
|
|
|
ua = useragent.Parse(uaStr)
|
|
|
|
if err = reflectClient(&ua, &newR); err != nil {
|
|
return
|
|
}
|
|
|
|
newR.ua = &ua
|
|
|
|
r = &newR
|
|
|
|
return
|
|
}
|
|
|
|
// NewServer returns a Server ready to use. Be sure to call Close to free up resources when done.
|
|
func NewServer(log logging.Logger, cliArgs *args.Args) (srv *Server, err error) {
|
|
|
|
var s Server
|
|
var udsSockPerms args.UdsPerms
|
|
|
|
if log == nil {
|
|
log = &logging.NullLogger{}
|
|
}
|
|
if cliArgs == nil {
|
|
err = ErrNoArgs
|
|
log.Err("server.NewServer: Received error creating server: %v", err)
|
|
return
|
|
}
|
|
|
|
s = Server{
|
|
log: log,
|
|
args: cliArgs,
|
|
mux: http.NewServeMux(),
|
|
sock: nil,
|
|
reloadChan: make(chan os.Signal),
|
|
stopChan: make(chan os.Signal),
|
|
}
|
|
|
|
s.mux.HandleFunc("/", s.handleDefault)
|
|
s.mux.HandleFunc("/about", s.handleAbout)
|
|
s.mux.HandleFunc("/about.html", s.handleAbout)
|
|
s.mux.HandleFunc("/usage", s.handleUsage)
|
|
s.mux.HandleFunc("/usage.html", s.handleUsage)
|
|
s.mux.HandleFunc("/favicon.ico", s.explicit404)
|
|
|
|
if s.listenUri, err = url.Parse(cliArgs.Listen.Listen); err != nil {
|
|
s.log.Err("server.NewServer: Failed to parse listener URI: %v", err)
|
|
return
|
|
}
|
|
s.listenUri.Scheme = strings.ToLower(s.listenUri.Scheme)
|
|
|
|
switch s.listenUri.Scheme {
|
|
case "unix":
|
|
if udsSockPerms, err = cliArgs.ModesAndOwners(); err != nil {
|
|
s.log.Err("server.NewServer: Failed to parse unix socket permissions: %v", err)
|
|
return
|
|
}
|
|
if err = paths.RealPath(&s.listenUri.Path); err != nil {
|
|
s.log.Err("server.NewServer: Failed to canonize/resolve socket path '%s': %v", s.listenUri.Path, err)
|
|
return
|
|
}
|
|
// Cleanup any stale socket.
|
|
if err = s.cleanup(true); err != nil {
|
|
s.log.Err("server.NewServer: Failed to cleanup for 'unix' listener: %v", err)
|
|
return
|
|
}
|
|
if err = os.MkdirAll(filepath.Dir(s.listenUri.Path), udsSockPerms.DMode); err != nil {
|
|
s.log.Err("server.NewServer: Received error creating socket directory '%s': %v", filepath.Dir(s.listenUri.Path), err)
|
|
return
|
|
}
|
|
|
|
if err = os.Chmod(filepath.Dir(s.listenUri.Path), udsSockPerms.DMode); err != nil {
|
|
s.log.Err("server.NewServer: Received error chmodding socket directory '%s': %v", filepath.Dir(s.listenUri.Path), err)
|
|
return
|
|
}
|
|
if err = os.Chown(filepath.Dir(s.listenUri.Path), udsSockPerms.UID, udsSockPerms.DGID); err != nil {
|
|
s.log.Err("server.NewServer: Received error chowning socket directory '%s': %v", filepath.Dir(s.listenUri.Path), err)
|
|
return
|
|
}
|
|
if s.listenUri, err = url.Parse(
|
|
fmt.Sprintf(
|
|
"%s://%s",
|
|
s.listenUri.Scheme, s.listenUri.Path,
|
|
),
|
|
); err != nil {
|
|
s.log.Err("server.NewServer: Failed to re-parse listener URI: %v", err)
|
|
return
|
|
}
|
|
if s.sock, err = net.Listen("unix", s.listenUri.Path); err != nil {
|
|
s.log.Err("server.NewServer: Failed to open socket on '%s': %v", s.listenUri.Path, err)
|
|
}
|
|
if err = os.Chmod(s.listenUri.Path, udsSockPerms.FMode); err != nil {
|
|
s.log.Err("server.NewServer: Received error chmodding socket '%s': %v", filepath.Dir(s.listenUri.Path), err)
|
|
return
|
|
}
|
|
if err = os.Chown(s.listenUri.Path, udsSockPerms.UID, udsSockPerms.FGID); err != nil {
|
|
s.log.Err("server.NewServer: Received error chowning socket '%s': %v", filepath.Dir(s.listenUri.Path), err)
|
|
return
|
|
}
|
|
case "http", "tcp":
|
|
s.isHttp = s.listenUri.Scheme == "http"
|
|
if err = s.cleanup(true); err != nil {
|
|
s.log.Err("server.NewServer: Failed to cleanup for '%s' listener: %v", strings.ToUpper(s.listenUri.Scheme), err)
|
|
return
|
|
}
|
|
if s.listenUri, err = url.Parse(
|
|
fmt.Sprintf(
|
|
"%s://%s%s",
|
|
s.listenUri.Scheme, s.listenUri.Host, s.listenUri.Path,
|
|
),
|
|
); err != nil {
|
|
s.log.Err("server.NewServer: Failed to re-parse listener URI: %v", err)
|
|
return
|
|
}
|
|
if s.sock, err = net.Listen("tcp", s.listenUri.Host); err != nil {
|
|
s.log.Err("server.NewServer: Failed to open %s socket on '%s': %v", strings.ToUpper(s.listenUri.Scheme), s.listenUri.Host, err)
|
|
return
|
|
}
|
|
default:
|
|
s.log.Err("server.NewServer: Unsupported scheme: %v", s.listenUri.Scheme)
|
|
err = ErrInvalidScheme
|
|
return
|
|
}
|
|
cliArgs.Listen.Listen = s.listenUri.String()
|
|
|
|
srv = &s
|
|
|
|
return
|
|
}
|
|
|
|
/*
|
|
decideParseAccept takes the slice returned from parseAccept, runs parseAccept on it,
|
|
and chooses based on what MIME types are supported by this program.
|
|
err will be an ErrUnsupportedMIME if no supported MIME type is found.
|
|
If parsed is nil or empty, format will be defFormat and err will be nil.
|
|
*/
|
|
func decideParseAccept(parsed []*parsedMIME, defFormat string) (format string, err error) {
|
|
|
|
var customFmtFound bool
|
|
|
|
if parsed == nil || len(parsed) == 0 {
|
|
format = defFormat
|
|
return
|
|
}
|
|
|
|
for _, pf := range parsed {
|
|
switch pf.MIME {
|
|
case "*/*": // Client explicitly accept anything
|
|
format = defFormat
|
|
customFmtFound = true
|
|
case "application/*": // Use JSON
|
|
format = mediaJSON
|
|
customFmtFound = true
|
|
case "text/*": // Use HTML
|
|
format = mediaHTML
|
|
customFmtFound = true
|
|
case mediaHTML, mediaJSON, mediaXML, mediaYAML:
|
|
format = pf.MIME
|
|
customFmtFound = true
|
|
}
|
|
if customFmtFound {
|
|
break
|
|
}
|
|
}
|
|
|
|
if !customFmtFound {
|
|
format = defFormat
|
|
err = ErrUnsupportedMIME
|
|
return
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
/*
|
|
reflectClient takes a src and dst and attempts to set/convert src to dst. It is *VERY STRICT*.
|
|
It is expected that src does NOT use pointers.
|
|
...This is pretty much just custom-made for converting a useragent.UserAgent to a R00tClient.
|
|
Don't use it for anything else.
|
|
*/
|
|
func reflectClient(src, dst any) (err error) {
|
|
|
|
var dstField reflect.StructField
|
|
var dstFieldVal reflect.Value
|
|
var srcFieldVal reflect.Value
|
|
var srcField string
|
|
var ok bool
|
|
var intVal *int
|
|
var strVal *string
|
|
var boolVal *bool
|
|
var srcVal reflect.Value = reflect.ValueOf(src)
|
|
var dstVal reflect.Value = reflect.ValueOf(dst)
|
|
|
|
// Both must be ptrs to a struct
|
|
if srcVal.Kind() != reflect.Ptr || dstVal.Kind() != reflect.Ptr {
|
|
err = ErrPtrNeeded
|
|
return
|
|
}
|
|
|
|
srcVal = srcVal.Elem()
|
|
dstVal = dstVal.Elem()
|
|
|
|
/*
|
|
Now that we have the underlying type/value of the ptr above,
|
|
check for structs.
|
|
*/
|
|
if srcVal.Kind() != reflect.Struct || dstVal.Kind() != reflect.Struct {
|
|
err = ErrStructNeeded
|
|
return
|
|
}
|
|
|
|
for i := 0; i < dstVal.NumField(); i++ {
|
|
dstField = dstVal.Type().Field(i)
|
|
dstFieldVal = dstVal.Field(i)
|
|
|
|
// Skip unexported
|
|
if !dstFieldVal.CanSet() {
|
|
continue
|
|
}
|
|
srcField = dstField.Tag.Get(convertTag)
|
|
// Skip explicitly skipped (<convertTag>:"-")
|
|
if srcField == "-" {
|
|
continue
|
|
}
|
|
// If no explicit field name is present, set it to the dst field name.
|
|
if _, ok = dstField.Tag.Lookup(convertTag); !ok {
|
|
srcField = dstField.Name
|
|
}
|
|
// Get the value from src
|
|
srcFieldVal = srcVal.FieldByName(srcField)
|
|
// Skip invalid...
|
|
if !srcFieldVal.IsValid() {
|
|
continue
|
|
}
|
|
// And zero-value.
|
|
if reflect.DeepEqual(srcFieldVal.Interface(), reflect.Zero(srcFieldVal.Type()).Interface()) {
|
|
continue
|
|
}
|
|
// Structs need to recurse.
|
|
if dstFieldVal.Kind() == reflect.Ptr && dstFieldVal.Type().Elem().Kind() == reflect.Struct {
|
|
// Ensure we don't have a nil ptr
|
|
if dstFieldVal.IsNil() {
|
|
dstFieldVal.Set(reflect.New(dstFieldVal.Type().Elem()))
|
|
}
|
|
// And recurse into it.
|
|
if err = reflectClient(srcFieldVal.Addr().Interface(), dstFieldVal.Interface()); err != nil {
|
|
return
|
|
}
|
|
} else {
|
|
// Everything else gets assigned here.
|
|
switch dstFieldVal.Kind() {
|
|
case reflect.Bool:
|
|
if srcFieldVal.Kind() == reflect.Bool {
|
|
dstFieldVal.Set(reflect.ValueOf(srcFieldVal.Interface().(bool)))
|
|
} else {
|
|
err = ErrIncompatFieldType
|
|
return
|
|
}
|
|
case reflect.String:
|
|
if srcFieldVal.Kind() == reflect.String {
|
|
dstFieldVal.Set(reflect.ValueOf(srcFieldVal.Interface().(string)))
|
|
} else {
|
|
err = ErrIncompatFieldType
|
|
return
|
|
}
|
|
case reflect.Int:
|
|
if srcFieldVal.Kind() == reflect.Int {
|
|
dstFieldVal.Set(reflect.ValueOf(srcFieldVal.Interface().(int)))
|
|
} else {
|
|
err = ErrIncompatFieldType
|
|
return
|
|
}
|
|
case reflect.Ptr:
|
|
// Pointers to above
|
|
switch dstFieldVal.Type().Elem().Kind() {
|
|
case reflect.Bool:
|
|
if srcFieldVal.Kind() == reflect.Bool {
|
|
boolVal = new(bool)
|
|
*boolVal = srcFieldVal.Interface().(bool)
|
|
dstFieldVal.Set(reflect.ValueOf(boolVal))
|
|
} else {
|
|
err = ErrIncompatFieldType
|
|
return
|
|
}
|
|
case reflect.String:
|
|
if srcFieldVal.Kind() == reflect.String {
|
|
strVal = new(string)
|
|
*strVal = srcFieldVal.Interface().(string)
|
|
dstFieldVal.Set(reflect.ValueOf(strVal))
|
|
} else {
|
|
err = ErrIncompatFieldType
|
|
return
|
|
}
|
|
case reflect.Int:
|
|
if srcFieldVal.Kind() == reflect.Int {
|
|
intVal = new(int)
|
|
*intVal = srcFieldVal.Interface().(int)
|
|
dstFieldVal.Set(reflect.ValueOf(intVal))
|
|
} else {
|
|
err = ErrIncompatFieldType
|
|
return
|
|
}
|
|
default:
|
|
err = ErrUnhandledField
|
|
return
|
|
}
|
|
default:
|
|
err = ErrUnhandledField
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
// parseAccept parses an Accept header as per RFC 9110 § 12.5.1.
|
|
func parseAccept(hdrVal string) (parsed []*parsedMIME, err error) {
|
|
|
|
var mimes []string
|
|
var parts []string
|
|
var params []string
|
|
var paramsLen int
|
|
var kv []string
|
|
var mt *parsedMIME
|
|
var mErr *multierr.MultiError = multierr.NewMultiError(nil)
|
|
|
|
if hdrVal == "" {
|
|
return
|
|
}
|
|
|
|
mimes = strings.Split(hdrVal, ",")
|
|
|
|
for _, mime := range mimes {
|
|
mt = &parsedMIME{
|
|
MIME: "",
|
|
Weight: 1.0, // between 0.0 and 1.0
|
|
Params: nil,
|
|
}
|
|
mime = strings.TrimSpace(mime)
|
|
// Split into []string{<type>[, <param>, ...]}
|
|
parts = strings.Split(mime, ";")
|
|
if parts == nil || len(parts) < 1 {
|
|
mErr.AddError(ErrInvalidAccept)
|
|
continue
|
|
}
|
|
if parts[0] == "" {
|
|
mErr.AddError(ErrInvalidAccept)
|
|
continue
|
|
}
|
|
if len(strings.Split(parts[0], "/")) != 2 {
|
|
mErr.AddError(ErrInvalidAccept)
|
|
continue
|
|
}
|
|
mt.MIME = strings.TrimSpace(parts[0])
|
|
if len(parts) > 1 {
|
|
// Parameters were provided. We don't really use them except `q`, but...
|
|
params = parts[1:]
|
|
paramsLen = len(params)
|
|
for idx, param := range params {
|
|
param = strings.TrimSpace(param)
|
|
kv = strings.SplitN(param, "=", 2)
|
|
if len(kv) != 2 {
|
|
mErr.AddError(ErrInvalidAccept)
|
|
continue
|
|
}
|
|
if kv[0] == "q" && idx == paramsLen-1 {
|
|
// It's the weight. RFC's pretty clear it's the last param.
|
|
fmt.Sscanf(kv[1], "%f", &mt.Weight)
|
|
if mt.Weight > 1.0 || mt.Weight < 0.0 {
|
|
mErr.AddError(ErrInvalidAccept)
|
|
continue
|
|
}
|
|
} else {
|
|
if mt.Params == nil {
|
|
mt.Params = make(map[string]string)
|
|
}
|
|
mt.Params[kv[0]] = kv[1]
|
|
}
|
|
}
|
|
}
|
|
parsed = append(parsed, mt)
|
|
}
|
|
|
|
// Now sort by weight (descending).
|
|
sort.SliceStable(
|
|
parsed,
|
|
func(i, j int) (isBefore bool) {
|
|
isBefore = parsed[i].Weight > parsed[j].Weight
|
|
return
|
|
},
|
|
)
|
|
|
|
if !mErr.IsEmpty() {
|
|
err = mErr
|
|
return
|
|
}
|
|
|
|
return
|
|
}
|