go_sysutils/cryptparse/funcs_tlsflat.go
brent saner db20c70d86
v1.4.0
ADDED:
* cryptparse, which makes managing certs/keys MUCH much easier (and
  better) than the stdlib utilities.
2024-06-21 17:18:19 -04:00

218 lines
4.3 KiB
Go

package cryptparse
import (
`bytes`
`crypto`
`crypto/tls`
`crypto/x509`
`errors`
`fmt`
`net/url`
`os`
`strings`
`r00t2.io/sysutils/paths`
)
// Normalize ensures that all specified filepaths are absolute, etc.
//
// Each certificate file, each key file that is set, and each CA file is handed
// to paths.RealPath by pointer, so the stored value is updated in place.
// Processing stops at the first path that fails, and that error is returned.
func (t *TlsFlat) Normalize() (err error) {

	// Ranging over a nil slice is a no-op, so no explicit nil guards are needed.
	// NOTE(review): crt is the range value; if Certs holds struct values rather
	// than pointers, the CertFile rewrite below lands on a copy — confirm against
	// the TlsFlat definition.
	for _, crt := range t.Certs {
		if err = paths.RealPath(&crt.CertFile); err != nil {
			return err
		}
		// The key is optional; only normalize it when one was given.
		if crt.KeyFile == nil {
			continue
		}
		if err = paths.RealPath(crt.KeyFile); err != nil {
			return err
		}
	}

	// Index into the slice so the element itself (not a copy) is rewritten.
	for i := range t.CaFiles {
		if err = paths.RealPath(&t.CaFiles[i]); err != nil {
			return err
		}
	}

	return nil
}
/*
ToTlsConfig returns a tls.Config from a TlsFlat. Note that it will have Normalize called on it,
and the TlsFlat is validated afterwards.

If CaFiles is nil, the system certificate pool is used for RootCAs. Otherwise the contents of all
CA files are concatenated and handed to ParseCA. CA files that do not exist are skipped; any other
error encountered while reading a CA file is returned.

Each entry in Certs has its (optional) key file and its cert file read and handed to ParseLeafCert,
along with any intermediates returned by ParseCA.

On any error, tlsConf is nil.

Unfortunately it's not possible for this library to do the reverse, as CA certificates are not able to be extracted from an x509.CertPool.
*/
func (t *TlsFlat) ToTlsConfig() (tlsConf *tls.Config, err error) {

	var b []byte
	var rootCAs *x509.CertPool
	var intermediateCAs []*x509.Certificate
	var privKeys []crypto.PrivateKey
	var tlsCerts []tls.Certificate
	var parsedTlsCerts []tls.Certificate
	var ciphers []uint16
	var curves []tls.CurveID
	var minVer uint16
	var maxVer uint16
	var buf = new(bytes.Buffer)

	// Normalize any filepaths before validation.
	if err = t.Normalize(); err != nil {
		return
	}

	// And validate.
	if err = validate.Struct(t); err != nil {
		return
	}

	// CA cert(s).
	// A nil CaFiles means "use the system pool"; a non-nil (even empty) CaFiles is parsed as given.
	if t.CaFiles != nil {
		for _, c := range t.CaFiles {
			if b, err = os.ReadFile(c); err != nil {
				if errors.Is(err, os.ErrNotExist) {
					// Missing CA files are deliberately tolerated.
					err = nil
					continue
				}
				// Anything else (permissions, I/O, path is a directory, ...) must not be
				// silently dropped; otherwise the CA would just vanish from the pool.
				return
			}
			// NOTE(review): files are concatenated as-is; presumably each one ends with a
			// newline so the combined data remains parseable — confirm against ParseCA.
			buf.Write(b)
		}
		if rootCAs, _, intermediateCAs, err = ParseCA(buf.Bytes()); err != nil {
			return
		}
	} else {
		if rootCAs, err = x509.SystemCertPool(); err != nil {
			return
		}
	}

	// Keys and Certs. They are assumed to be matched.
	for _, c := range t.Certs {
		// Reset per entry so a key from a previous cert never leaks into a keyless one.
		privKeys = nil
		if c.KeyFile != nil {
			if b, err = os.ReadFile(*c.KeyFile); err != nil {
				return
			}
			if privKeys, err = ParsePrivateKey(b); err != nil {
				return
			}
		}
		if b, err = os.ReadFile(c.CertFile); err != nil {
			return
		}
		if parsedTlsCerts, err = ParseLeafCert(b, privKeys, intermediateCAs...); err != nil {
			return
		}
		tlsCerts = append(tlsCerts, parsedTlsCerts...)
	}

	// Ciphers.
	if t.CipherSuites != nil {
		ciphers = ParseTlsCiphers(strings.Join(t.CipherSuites, ","))
	}

	// Minimum TLS Protocol Version.
	if t.MinTlsProtocol != nil {
		if minVer, err = ParseTlsVersion(*t.MinTlsProtocol); err != nil {
			return
		}
	}

	// Maximum TLS Protocol Version.
	if t.MaxTlsProtocol != nil {
		if maxVer, err = ParseTlsVersion(*t.MaxTlsProtocol); err != nil {
			return
		}
	}

	// Curves.
	if t.Curves != nil {
		curves = ParseTlsCurves(strings.Join(t.Curves, ","))
	}

	// Fields left at their zero value (e.g. no ciphers/versions/curves given) are passed through as-is.
	tlsConf = &tls.Config{
		Certificates:       tlsCerts,
		RootCAs:            rootCAs,
		ServerName:         t.SniName,
		InsecureSkipVerify: t.SkipVerify,
		CipherSuites:       ciphers,
		MinVersion:         minVer,
		MaxVersion:         maxVer,
		CurvePreferences:   curves,
	}

	return
}
// ToTlsUri returns a TlsUri from a TlsFlat.
//
// The URI has the form "tls://<SniName>/" and carries the rest of the TlsFlat
// as query parameters: one parameter per CA file, cert file, key file, cipher
// suite and curve, plus the SNI name, the skip-verify flag (as "1", only when
// set) and the minimum/maximum TLS protocol versions (only when set).
//
// The query string is produced by url.Values.Encode, so parameters are sorted
// by key; values that share a key keep the order in which they appear in t.
//
// An error is returned if the base URI cannot be parsed.
func (t *TlsFlat) ToTlsUri() (tlsUri *TlsUri, err error) {

	var u *url.URL
	var q url.Values

	if u, err = url.Parse(fmt.Sprintf("tls://%v/", t.SniName)); err != nil {
		return
	}

	// url.URL.Query parses RawQuery into a fresh url.Values on every call, so
	// adding to its result directly would be discarded. Collect everything in
	// one url.Values and write it back to u.RawQuery at the end.
	q = u.Query()

	// CA cert(s).
	for _, c := range t.CaFiles {
		q.Add(TlsUriParamCa, c)
	}

	// Keys and Certs.
	for _, c := range t.Certs {
		q.Add(TlsUriParamCert, c.CertFile)
		if c.KeyFile != nil {
			q.Add(TlsUriParamKey, *c.KeyFile)
		}
	}

	// Enforce the SNI hostname.
	q.Add(TlsUriParamSni, t.SniName)

	// Disable Verification.
	if t.SkipVerify {
		q.Add(TlsUriParamNoVerify, "1")
	}

	// Ciphers.
	for _, c := range t.CipherSuites {
		q.Add(TlsUriParamCipher, c)
	}

	// Minimum TLS Protocol Version.
	if t.MinTlsProtocol != nil {
		q.Add(TlsUriParamMinTls, *t.MinTlsProtocol)
	}

	// Maximum TLS Protocol Version.
	if t.MaxTlsProtocol != nil {
		q.Add(TlsUriParamMaxTls, *t.MaxTlsProtocol)
	}

	// Curves.
	for _, c := range t.Curves {
		q.Add(TlsUriParamCurve, c)
	}

	u.RawQuery = q.Encode()

	tlsUri = &TlsUri{
		URL: u,
	}

	return
}