224 lines
4.4 KiB
Go
224 lines
4.4 KiB
Go
package cryptparse
|
|
|
|
import (
|
|
`bytes`
|
|
`crypto`
|
|
`crypto/tls`
|
|
`crypto/x509`
|
|
`errors`
|
|
`fmt`
|
|
`net/url`
|
|
`os`
|
|
`strings`
|
|
|
|
`r00t2.io/sysutils/paths`
|
|
)
|
|
|
|
// Normalize ensures that all specified filepaths are absolute, etc.
|
|
func (t *TlsFlat) Normalize() (err error) {
|
|
|
|
if t.Certs != nil {
|
|
for _, c := range t.Certs {
|
|
if err = paths.RealPath(&c.CertFile); err != nil {
|
|
return
|
|
}
|
|
if c.KeyFile != nil {
|
|
if err = paths.RealPath(c.KeyFile); err != nil {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
if t.CaFiles != nil {
|
|
for idx, _ := range t.CaFiles {
|
|
if err = paths.RealPath(&t.CaFiles[idx]); err != nil {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
/*
|
|
ToTlsConfig returns a tls.Config from a TlsFlat. Note that it will have Normalize called on it.
|
|
|
|
Unfortunately it's not possible for this library to do the reverse, as CA certificates are not able to be extracted from an x509.CertPool.
|
|
*/
|
|
func (t *TlsFlat) ToTlsConfig() (tlsConf *tls.Config, err error) {
|
|
|
|
var b []byte
|
|
var rootCAs *x509.CertPool
|
|
var intermediateCAs []*x509.Certificate
|
|
var privKeys []crypto.PrivateKey
|
|
var tlsCerts []tls.Certificate
|
|
var parsedTlsCerts []tls.Certificate
|
|
var ciphers []uint16
|
|
var curves []tls.CurveID
|
|
var minVer uint16
|
|
var maxVer uint16
|
|
var concatCAs []*x509.Certificate
|
|
var buf *bytes.Buffer = new(bytes.Buffer)
|
|
var srvNm string = t.SniName
|
|
|
|
// Normalize any filepaths before validation.
|
|
if err = t.Normalize(); err != nil {
|
|
return
|
|
}
|
|
|
|
// And validate.
|
|
if err = validate.Struct(t); err != nil {
|
|
return
|
|
}
|
|
|
|
// CA cert(s).
|
|
buf.Reset()
|
|
if t.CaFiles != nil {
|
|
rootCAs = x509.NewCertPool()
|
|
for _, c := range t.CaFiles {
|
|
if b, err = os.ReadFile(c); err != nil {
|
|
if errors.Is(err, os.ErrNotExist) {
|
|
err = nil
|
|
continue
|
|
}
|
|
}
|
|
buf.Write(b)
|
|
}
|
|
if rootCAs, _, intermediateCAs, err = ParseCA(buf.Bytes()); err != nil {
|
|
return
|
|
}
|
|
} else {
|
|
if rootCAs, err = x509.SystemCertPool(); err != nil {
|
|
return
|
|
}
|
|
}
|
|
|
|
// Keys and Certs. They are assumed to be matched.
|
|
if t.Certs != nil {
|
|
for _, c := range t.Certs {
|
|
privKeys = nil
|
|
if c.KeyFile != nil {
|
|
if b, err = os.ReadFile(*c.KeyFile); err != nil {
|
|
return
|
|
}
|
|
if privKeys, err = ParsePrivateKey(b); err != nil {
|
|
return
|
|
}
|
|
}
|
|
if b, err = os.ReadFile(c.CertFile); err != nil {
|
|
return
|
|
}
|
|
if parsedTlsCerts, concatCAs, err = ParseLeafCert(b, privKeys, intermediateCAs...); err != nil {
|
|
return
|
|
}
|
|
tlsCerts = append(tlsCerts, parsedTlsCerts...)
|
|
if concatCAs != nil {
|
|
for _, ca := range concatCAs {
|
|
rootCAs.AddCert(ca)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Ciphers.
|
|
if t.CipherSuites != nil {
|
|
ciphers = ParseTlsCiphers(strings.Join(t.CipherSuites, ","))
|
|
}
|
|
|
|
// Minimum TLS Protocol Version.
|
|
if t.MinTlsProtocol != nil {
|
|
if minVer, err = ParseTlsVersion(*t.MinTlsProtocol); err != nil {
|
|
return
|
|
}
|
|
}
|
|
|
|
// Maximum TLS Protocol Version.
|
|
if t.MaxTlsProtocol != nil {
|
|
if maxVer, err = ParseTlsVersion(*t.MaxTlsProtocol); err != nil {
|
|
return
|
|
}
|
|
}
|
|
|
|
// Curves.
|
|
if t.Curves != nil {
|
|
curves = ParseTlsCurves(strings.Join(t.Curves, ","))
|
|
}
|
|
|
|
tlsConf = &tls.Config{
|
|
Certificates: tlsCerts,
|
|
RootCAs: rootCAs,
|
|
ServerName: srvNm,
|
|
InsecureSkipVerify: t.SkipVerify,
|
|
CipherSuites: ciphers,
|
|
MinVersion: minVer,
|
|
MaxVersion: maxVer,
|
|
CurvePreferences: curves,
|
|
}
|
|
return
|
|
}
|
|
|
|
// ToTlsUri returns a TlsUri from a TlsFlat.
|
|
func (t *TlsFlat) ToTlsUri() (tlsUri *TlsUri, err error) {
|
|
|
|
var u *url.URL
|
|
|
|
if u, err = url.Parse(fmt.Sprintf("tls://%v/", t.SniName)); err != nil {
|
|
return
|
|
}
|
|
|
|
// CA cert(s).
|
|
if t.CaFiles != nil {
|
|
for _, c := range t.CaFiles {
|
|
u.Query().Add(ParamCa, c)
|
|
}
|
|
}
|
|
|
|
// Keys and Certs.
|
|
if t.Certs != nil {
|
|
for _, c := range t.Certs {
|
|
u.Query().Add(ParamCert, c.CertFile)
|
|
if c.KeyFile != nil {
|
|
u.Query().Add(ParamKey, *c.KeyFile)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Enforce the SNI hostname.
|
|
u.Query().Add(ParamSni, t.SniName)
|
|
|
|
// Disable Verification.
|
|
if t.SkipVerify {
|
|
u.Query().Add(ParamNoVerify, "1")
|
|
}
|
|
|
|
// Ciphers.
|
|
if t.CipherSuites != nil {
|
|
for _, c := range t.CipherSuites {
|
|
u.Query().Add(ParamCipher, c)
|
|
}
|
|
}
|
|
|
|
// Minimum TLS Protocol Version.
|
|
if t.MinTlsProtocol != nil {
|
|
u.Query().Add(ParamMinTls, *t.MinTlsProtocol)
|
|
}
|
|
|
|
// Maximum TLS Protocol Version.
|
|
if t.MaxTlsProtocol != nil {
|
|
u.Query().Add(ParamMaxTls, *t.MaxTlsProtocol)
|
|
}
|
|
|
|
// Curves.
|
|
if t.Curves != nil {
|
|
for _, c := range t.Curves {
|
|
u.Query().Add(ParamCurve, c)
|
|
}
|
|
}
|
|
|
|
tlsUri = &TlsUri{
|
|
URL: u,
|
|
}
|
|
|
|
return
|
|
}
|