package main

import (
	`bytes`
	`crypto`
	`crypto/ecdh`
	`crypto/ecdsa`
	`crypto/ed25519`
	`crypto/elliptic`
	`crypto/rand`
	`crypto/rsa`
	`crypto/x509`
	`crypto/x509/pkix`
	`encoding/binary`
	`encoding/pem`
	`errors`
	`fmt`
	`log`
	`math/big`
	`os`
	`path/filepath`

	`github.com/brunoga/deep`
	`github.com/google/uuid`
)

/*
	getKeyFpath returns a (relative) path for a key PEM file (PKCS#8).

	This is used both when fetching from embed.FS and when generating new keys.
	The result is always slash-separated: io/fs (and thus embed.FS) paths use
	forward slashes on every OS, and the os package accepts them on Windows too.
*/
func getKeyFpath(pairType, keyType string) (path string) {

	path = filepath.ToSlash(filepath.Join("_testdata", fmt.Sprintf("%s_%s_key.pem", pairType, keyType)))

	return
}

/*
	getCertFpath returns a (relative) path for a cert PEM file.

	This is used both when fetching from embed.FS and when generating new certs.
	The result is always slash-separated: io/fs (and thus embed.FS) paths use
	forward slashes on every OS, and the os package accepts them on Windows too.
*/
func getCertFpath(pairType, keyType string) (path string) {

	path = filepath.ToSlash(filepath.Join("_testdata", fmt.Sprintf("%s_%s_cert.pem", pairType, keyType)))

	return
}

/*
	getChainFpath returns a (relative) path for a chained cert PEM file.

	This is used both when fetching from embed.FS and when generating new certs.
	The result is always slash-separated: io/fs (and thus embed.FS) paths use
	forward slashes on every OS, and the os package accepts them on Windows too.
*/
func getChainFpath(pairType, keyType string) (path string) {

	path = filepath.ToSlash(filepath.Join("_testdata", fmt.Sprintf("%s_%s_cert_chained.pem", pairType, keyType)))

	return
}

/*
	getCsrFpath returns a (relative) path for a CSR PEM file.

	This is used both when fetching from embed.FS and when generating new CSRs.
	The result is always slash-separated: io/fs (and thus embed.FS) paths use
	forward slashes on every OS, and the os package accepts them on Windows too.
*/
func getCsrFpath(pairType, keyType string) (path string) {

	path = filepath.ToSlash(filepath.Join("_testdata", fmt.Sprintf("%s_%s_csr.pem", pairType, keyType)))

	return
}

/*
	getKeypair takes cert type t and key type kt and returns the crypto.PrivateKey and crypto.PublicKey for it.

	It assumes that loadKeys() at the *least* has already been called. pub is obtained from priv's
	Public method and is left nil when the stored value is not an *ecdh, *ecdsa, ed25519 or *rsa
	private key (including when nothing is stored for kt).
*/
func getKeypair(t, kt string) (priv crypto.PrivateKey, pub crypto.PublicKey) {

	// All four supported private key types expose their public half the same way.
	type publicKeyer interface {
		Public() crypto.PublicKey
	}

	priv = pairs[t].privKeys[kt]

	// ed25519.PrivateKey is listed without a pointer on purpose; it is a byte-slice type.
	switch priv.(type) {
	case *ecdh.PrivateKey, *ecdsa.PrivateKey, ed25519.PrivateKey, *rsa.PrivateKey:
		pub = priv.(publicKeyer).Public()
	}

	return
}

/*
	getSerial returns a (pseudo-)random certificate serial number based on a UUIDv4 (RFC 4122 type 4).

	All 128 bits of the UUID are read as one unsigned big-endian integer. That keeps all of a
	UUIDv4's random bits (well over the 64 bits of CSPRNG output the CA/Browser Forum asks for),
	and the value is always positive (the version nibble makes the high half non-zero). It also
	fits comfortably within RFC 5280's 20-octet serial limit. Unlike negating a signed 64-bit
	value, it cannot overflow.

	This guarantees not only that renewals (if issued/implemented) are reasonably guaranteed
	to be different from the past issuance but also that the serial issuance is non-sequential
	(both are common requirements of modern browser-trusted CAs; see
	https://cabforum.org/working-groups/server/baseline-requirements/documents/)
*/
func getSerial() (serial *big.Int) {

	var u uuid.UUID = uuid.New()
	var lo *big.Int

	// serial = hi<<64 | lo, from the two big-endian 64-bit halves of the UUID.
	serial = new(big.Int).SetUint64(binary.BigEndian.Uint64(u[:8]))
	lo = new(big.Int).SetUint64(binary.BigEndian.Uint64(u[8:]))
	serial.Lsh(serial, 64)
	serial.Or(serial, lo)

	return
}

// getTpl returns a shallow copy of certificate template tpl carrying a fresh random serial
// (from getSerial). tpl itself is left untouched; slice/pointer fields are shared with it.
func getTpl(tpl *x509.Certificate) (newTpl *x509.Certificate) {

	clone := *tpl
	clone.SerialNumber = getSerial()
	newTpl = &clone

	return
}

// getSubj returns a cert/CSR-specific pkix.Name from a given cn (commonName).
func getSubj(cn string) (newSubj pkix.Name) {

	newSubj = deep.MustCopy(*pkixCommon)
	newSubj.CommonName = cn

	return
}

/*
	loadKeys either loads from pems or generates and writes out the PEM keys.

	For every t in pairTypes, pairs[t] is first ensured to hold a *Pair with every map initialized.
	Then, for every kt in keyTypes, the PEM-wrapped PKCS#8 key at getKeyFpath(t, kt) is read from
	pems and parsed. Keys whose file does not exist are generated (X25519 for "ecdh", P-521 for
	"ecdsa", Ed25519 for "ed25519", 4096-bit RSA for "rsa"), PEM-encoded, and written to that
	path with mode 0600.
	keyBytes always holds the PEM form and privKeys the parsed key. Any other key type is an error.
*/
func loadKeys() (err error) {

	var b []byte
	var t string
	var kt string
	var ok bool
	var priv crypto.PrivateKey
	var pemBlock *pem.Block
	var keybuf *bytes.Buffer = new(bytes.Buffer)

	// Load in any existing keys.
	for _, t = range pairTypes {
		// Create each Pair once, before its keys are loaded, with *all* maps initialized:
		// re-creating it when one key file is missing would discard keys already loaded
		// for this type, and the later cert/chain loaders write into every one of these maps.
		if _, ok = pairs[t]; !ok {
			pairs[t] = &Pair{
				pairType:         t,
				keyBytes:         make(map[string][]byte),
				privKeys:         make(map[string]crypto.PrivateKey),
				certBytes:        make(map[string][]byte),
				certs:            make(map[string]*x509.Certificate),
				csrBytes:         make(map[string][]byte),
				csrs:             make(map[string]*x509.CertificateRequest),
				chainParentBytes: make(map[string][]byte),
				chainParent:      make(map[string]*x509.Certificate),
			}
		}
		for _, kt = range keyTypes {
			log.Printf("Loading %s key %s\n", t, kt)
			if b, err = pems.ReadFile(getKeyFpath(t, kt)); err != nil {
				if errors.Is(err, os.ErrNotExist) {
					// Will generate missing below
					err = nil
					continue
				}
				return
			}
			// Key files are PEM (see the generation path below); PKCS#8 parsing wants the DER inside.
			if pemBlock, _ = pem.Decode(b); pemBlock == nil {
				err = fmt.Errorf("%s: no PEM block found", getKeyFpath(t, kt))
				return
			}
			if priv, err = x509.ParsePKCS8PrivateKey(pemBlock.Bytes); err != nil {
				err = fmt.Errorf("%s: %w", getKeyFpath(t, kt), err)
				return
			}
			pairs[t].keyBytes[kt] = b
			pairs[t].privKeys[kt] = priv
		}
	}

	// Generate any missing keys.
	for _, t = range pairTypes {
		for _, kt = range keyTypes {
			if _, ok = pairs[t].privKeys[kt]; ok {
				continue
			}
			log.Printf("Generating %s key %s\n", t, kt)
			switch kt {
			case "ecdh":
				priv, err = ecdh.X25519().GenerateKey(rand.Reader)
			case "ecdsa":
				priv, err = ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
			case "ed25519":
				_, priv, err = ed25519.GenerateKey(rand.Reader)
			case "rsa":
				priv, err = rsa.GenerateKey(rand.Reader, 4096)
			default:
				err = fmt.Errorf("cannot generate %s key: unsupported key type %q", t, kt)
			}
			if err != nil {
				return
			}
			if b, err = x509.MarshalPKCS8PrivateKey(priv); err != nil {
				return
			}
			keybuf.Reset()
			if err = pem.Encode(keybuf, &pem.Block{Type: "PRIVATE KEY", Bytes: b}); err != nil {
				return
			}
			// keybuf's storage is reused on the next iteration, so store a private copy.
			b = bytes.Clone(keybuf.Bytes())
			pairs[t].privKeys[kt] = priv
			pairs[t].keyBytes[kt] = b
			if err = os.WriteFile(getKeyFpath(t, kt), b, 0o0600); err != nil {
				return
			}
		}
	}

	return
}

// loadCerts combines all loadCert* functions in the proper order. It is expected that loadKeys has already been run.
func loadCerts() (err error) {

	var b []byte
	var t string
	var kt string
	var tkt [2]string
	var chainMissing [][2]string = make([][2]string, 0, (len(certgenOrder)-1)*len(keyTypes))

	if err = loadCertCa(); err != nil {
		return
	}
	if err = loadCertIssued(); err != nil {
		return
	}

	// And create chained certs of leaves so they fully validate.
	for _, t = range certgenOrder {
		if t == "inter" {
			continue // Don't bother with chaining the intermediate. If we play around with multiple intermediates, we will.
		}
		for _, kt = range keyTypes {
			pairs[t].chainParent[kt] = pairs[parents[t]].certs[kt]
			log.Printf("Loading %s chain %s\n", t, kt)
			if b, err = pems.ReadFile(getChainFpath(t, kt)); err != nil {
				if errors.Is(err, os.ErrNotExist) {
					err = nil
					chainMissing = append(chainMissing, [2]string{t, kt})
					continue
				}
				return
			}
			// Found
			pairs[t].chainParentBytes[kt] = b
		}
	}

	// "Generate" missing.
	for _, tkt = range chainMissing {
		t = tkt[0]
		kt = tkt[1]
		log.Printf("Building %s chain %s\n", t, kt)
		b = append(pairs[t].certBytes[kt], pairs[parents[t]].certBytes[kt]...)
		if err = os.WriteFile(getChainFpath(t, kt), b, 0o0600); err != nil {
			return
		}
	}

	return
}

/*
	loadCertCa loads (or generates) the root CA/anchor. It is expected that loadKeys has already been run.

	For each kt in keyTypes the CA certificate is read from pems at getCertFpath("ca", kt). Any that
	does not exist is created self-signed from a copy of certTpl["ca"] with a fresh serial, using the
	"ca" key pair, and written to that path with mode 0600. certBytes holds the PEM form and certs the
	parsed certificate.

	NOTE(review): x509.CreateCertificate needs a signing key implementing crypto.Signer, which an
	X25519 *ecdh.PrivateKey does not, so this returns an error if keyTypes contains "ecdh" — confirm.
*/
func loadCertCa() (err error) {

	var b []byte
	var kt string
	var ok bool
	var privKey crypto.PrivateKey
	var pubKey crypto.PublicKey
	var pemBlock *pem.Block
	var ktTpl *x509.Certificate

	for _, kt = range keyTypes {
		log.Printf("Loading CA certificate %s\n", kt)
		if b, err = pems.ReadFile(getCertFpath("ca", kt)); err != nil {
			if errors.Is(err, os.ErrNotExist) {
				// Will generate missing below.
				err = nil
				continue
			}
			return
		}
		// pem.Decode returns a nil block for non-PEM input; don't dereference it.
		if pemBlock, _ = pem.Decode(b); pemBlock == nil {
			err = fmt.Errorf("%s: no PEM block found", getCertFpath("ca", kt))
			return
		}
		// Assume the mapped Pair exists per loadKeys.
		pairs["ca"].certBytes[kt] = b
		if pairs["ca"].certs[kt], err = x509.ParseCertificate(pemBlock.Bytes); err != nil {
			return
		}
	}

	// Generate missing CA certs.
	for _, kt = range keyTypes {
		if _, ok = pairs["ca"].certs[kt]; ok {
			continue
		}
		log.Printf("Generating CA certificate %s\n", kt)
		ktTpl = getTpl(certTpl["ca"])
		privKey, pubKey = getKeypair("ca", kt)
		// Specifying the same cert template for both the template and parent params creates a self-signed.
		if b, err = x509.CreateCertificate(
			rand.Reader,
			ktTpl,
			ktTpl,
			pubKey,
			privKey,
		); err != nil {
			return
		}
		if pairs["ca"].certs[kt], err = x509.ParseCertificate(b); err != nil {
			return
		}
		pemBlock = &pem.Block{
			Type:    "CERTIFICATE",
			Headers: nil,
			Bytes:   b,
		}
		b = pem.EncodeToMemory(pemBlock)
		pairs["ca"].certBytes[kt] = b
		if err = os.WriteFile(getCertFpath("ca", kt), b, 0o0600); err != nil {
			return
		}
	}

	return
}

/*
	loadCertIssued handles the intermediate, "server" leaf, and "user" leaf.

	For every t in certgenOrder and kt in keyTypes, the CSR (getCsrFpath) and certificate (getCertFpath)
	are read from pems. Missing CSRs are generated from the csrs[t] template with the t/kt key and
	written out. Those CSRs also force their certificates to be re-issued.
	Each missing or forced certificate is issued from a copy of certTpl[t] (fresh serial). It is
	signed by the parents[t] pair's certificate and private key, then written out with mode 0600.
	Certificates are issued in certgenOrder, so a parent must come before its children there.
	Expects loadKeys and loadCertCa to have run.

	NOTE(review): CSR creation and signing need a crypto.Signer, which an X25519 *ecdh.PrivateKey
	is not; if keyTypes contains "ecdh" this returns an error — confirm.
*/
func loadCertIssued() (err error) {

	var b []byte
	var t string
	var kt string
	var tkt [2]string
	var caCert *x509.Certificate
	var caPrivKey crypto.PrivateKey
	var certPrivKey crypto.PrivateKey
	var certPubKey crypto.PublicKey
	var pemBlock *pem.Block
	var ktTpl *x509.Certificate
	// map[<t>][<kt>]; map so we can condense dupes
	var certMissing map[string]map[string]bool = make(map[string]map[string]bool)
	var csrMissing [][2]string = make([][2]string, 0, len(certgenOrder)*len(keyTypes))

	// markCertMissing flags the ct/ckt certificate for (re)generation.
	markCertMissing := func(ct, ckt string) {
		if certMissing[ct] == nil {
			certMissing[ct] = make(map[string]bool)
		}
		certMissing[ct][ckt] = true
	}

	// CSRS
	// Find existing CSRs and certs
	for _, t = range certgenOrder {
		for _, kt = range keyTypes {
			log.Printf("Loading %s CSR %s\n", t, kt)
			if b, err = pems.ReadFile(getCsrFpath(t, kt)); err != nil {
				if errors.Is(err, os.ErrNotExist) {
					err = nil
					csrMissing = append(csrMissing, [2]string{t, kt})
					continue
				}
				return
			}
			if pemBlock, _ = pem.Decode(b); pemBlock == nil {
				err = fmt.Errorf("%s: no PEM block found", getCsrFpath(t, kt))
				return
			}
			// Assume the mapped Pair exists per loadKeys.
			pairs[t].csrBytes[kt] = b
			if pairs[t].csrs[kt], err = x509.ParseCertificateRequest(pemBlock.Bytes); err != nil {
				return
			}
			log.Printf("Loading %s certificate %s\n", t, kt)
			if b, err = pems.ReadFile(getCertFpath(t, kt)); err != nil {
				if errors.Is(err, os.ErrNotExist) {
					err = nil
					markCertMissing(t, kt)
					continue
				}
				// Any other read error is fatal; b is unusable here.
				return
			}
			if pemBlock, _ = pem.Decode(b); pemBlock == nil {
				err = fmt.Errorf("%s: no PEM block found", getCertFpath(t, kt))
				return
			}
			pairs[t].certBytes[kt] = b
			if pairs[t].certs[kt], err = x509.ParseCertificate(pemBlock.Bytes); err != nil {
				return
			}
		}
	}

	// Generate missing CSRs.
	for _, tkt = range csrMissing {
		t, kt = tkt[0], tkt[1]
		log.Printf("Generating %s CSR %s\n", t, kt)
		certPrivKey, _ = getKeypair(t, kt)
		if b, err = x509.CreateCertificateRequest(rand.Reader, csrs[t], certPrivKey); err != nil {
			return
		}
		if pairs[t].csrs[kt], err = x509.ParseCertificateRequest(b); err != nil {
			return
		}
		pemBlock = &pem.Block{
			Type:    "CERTIFICATE REQUEST",
			Headers: nil,
			Bytes:   b,
		}
		b = pem.EncodeToMemory(pemBlock)
		pairs[t].csrBytes[kt] = b
		if err = os.WriteFile(getCsrFpath(t, kt), b, 0o0600); err != nil {
			return
		}
		// A new CSR means the existing certificate (if any) no longer matches it.
		markCertMissing(t, kt)
	}

	// Force re-gen of certs for above new CSRs and gen missing.
	for _, t = range certgenOrder {
		// Walk keyTypes rather than the map so generation (and log) order is deterministic.
		for _, kt = range keyTypes {
			if !certMissing[t][kt] {
				continue
			}
			log.Printf("Generating %s certificate %s\n", t, kt)
			caCert = pairs[parents[t]].certs[kt]
			if caCert == nil {
				err = fmt.Errorf("cannot issue %s certificate %s: parent %q has no %s certificate", t, kt, parents[t], kt)
				return
			}
			caPrivKey = pairs[parents[t]].privKeys[kt]
			_, certPubKey = getKeypair(t, kt)
			ktTpl = getTpl(certTpl[t])
			if b, err = x509.CreateCertificate(
				rand.Reader,
				ktTpl,
				caCert,
				certPubKey,
				caPrivKey,
			); err != nil {
				return
			}
			if pairs[t].certs[kt], err = x509.ParseCertificate(b); err != nil {
				return
			}
			pemBlock = &pem.Block{
				Type:    "CERTIFICATE",
				Headers: nil,
				Bytes:   b,
			}
			b = pem.EncodeToMemory(pemBlock)
			pairs[t].certBytes[kt] = b
			if err = os.WriteFile(getCertFpath(t, kt), b, 0o0600); err != nil {
				return
			}
		}
	}

	return
}