diff --git a/auth.go b/auth.go index c6421ce..a9c6a66 100644 --- a/auth.go +++ b/auth.go @@ -18,9 +18,13 @@ import ( "slices" ) -// verifyHandshakeSignature verifies a signature against pre-hashed -// (if required) handshake contents. +// verifyHandshakeSignature verifies a signature against unhashed handshake contents. func verifyHandshakeSignature(sigType uint8, pubkey crypto.PublicKey, hashFunc crypto.Hash, signed, sig []byte) error { + if hashFunc != directSigning { + h := hashFunc.New() + h.Write(signed) + signed = h.Sum(nil) + } switch sigType { case signatureECDSA: pubKey, ok := pubkey.(*ecdsa.PublicKey) @@ -61,6 +65,32 @@ func verifyHandshakeSignature(sigType uint8, pubkey crypto.PublicKey, hashFunc c return nil } +// verifyLegacyHandshakeSignature verifies a TLS 1.0 and 1.1 signature against +// pre-hashed handshake contents. +func verifyLegacyHandshakeSignature(sigType uint8, pubkey crypto.PublicKey, hashFunc crypto.Hash, hashed, sig []byte) error { + switch sigType { + case signatureECDSA: + pubKey, ok := pubkey.(*ecdsa.PublicKey) + if !ok { + return fmt.Errorf("expected an ECDSA public key, got %T", pubkey) + } + if !ecdsa.VerifyASN1(pubKey, hashed, sig) { + return errors.New("ECDSA verification failure") + } + case signaturePKCS1v15: + pubKey, ok := pubkey.(*rsa.PublicKey) + if !ok { + return fmt.Errorf("expected an RSA public key, got %T", pubkey) + } + if err := rsa.VerifyPKCS1v15(pubKey, hashFunc, hashed, sig); err != nil { + return err + } + default: + return errors.New("internal error: unknown signature type") + } + return nil +} + const ( serverSignatureContext = "TLS 1.3, server CertificateVerify\x00" clientSignatureContext = "TLS 1.3, client CertificateVerify\x00" @@ -77,21 +107,15 @@ var signaturePadding = []byte{ 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, } -// signedMessage returns the pre-hashed (if necessary) message to be signed by -// certificate keys in TLS 1.3. See RFC 8446, Section 4.4.3. -func signedMessage(sigHash crypto.Hash, context string, transcript hash.Hash) []byte { - if sigHash == directSigning { - b := &bytes.Buffer{} - b.Write(signaturePadding) - io.WriteString(b, context) - b.Write(transcript.Sum(nil)) - return b.Bytes() - } - h := sigHash.New() - h.Write(signaturePadding) - io.WriteString(h, context) - h.Write(transcript.Sum(nil)) - return h.Sum(nil) +// signedMessage returns the (unhashed) message to be signed by certificate keys +// in TLS 1.3. See RFC 8446, Section 4.4.3. +func signedMessage(context string, transcript hash.Hash) []byte { + const maxSize = 64 /* signaturePadding */ + len(serverSignatureContext) + 512/8 /* SHA-512 */ + b := bytes.NewBuffer(make([]byte, 0, maxSize)) + b.Write(signaturePadding) + io.WriteString(b, context) + b.Write(transcript.Sum(nil)) + return b.Bytes() } // typeAndHashFromSignatureScheme returns the corresponding signature type and @@ -149,90 +173,78 @@ func legacyTypeAndHashFromPublicKey(pub crypto.PublicKey) (sigType uint8, hash c var rsaSignatureSchemes = []struct { scheme SignatureScheme minModulusBytes int - maxVersion uint16 }{ // RSA-PSS is used with PSSSaltLengthEqualsHash, and requires // emLen >= hLen + sLen + 2 - {PSSWithSHA256, crypto.SHA256.Size()*2 + 2, VersionTLS13}, - {PSSWithSHA384, crypto.SHA384.Size()*2 + 2, VersionTLS13}, - {PSSWithSHA512, crypto.SHA512.Size()*2 + 2, VersionTLS13}, + {PSSWithSHA256, crypto.SHA256.Size()*2 + 2}, + {PSSWithSHA384, crypto.SHA384.Size()*2 + 2}, + {PSSWithSHA512, crypto.SHA512.Size()*2 + 2}, // PKCS #1 v1.5 uses prefixes from hashPrefixes in crypto/rsa, and requires // emLen >= len(prefix) + hLen + 11 - // TLS 1.3 dropped support for PKCS #1 v1.5 in favor of RSA-PSS. - {PKCS1WithSHA256, 19 + crypto.SHA256.Size() + 11, VersionTLS12}, - {PKCS1WithSHA384, 19 + crypto.SHA384.Size() + 11, VersionTLS12}, - {PKCS1WithSHA512, 19 + crypto.SHA512.Size() + 11, VersionTLS12}, - {PKCS1WithSHA1, 15 + crypto.SHA1.Size() + 11, VersionTLS12}, + {PKCS1WithSHA256, 19 + crypto.SHA256.Size() + 11}, + {PKCS1WithSHA384, 19 + crypto.SHA384.Size() + 11}, + {PKCS1WithSHA512, 19 + crypto.SHA512.Size() + 11}, + {PKCS1WithSHA1, 15 + crypto.SHA1.Size() + 11}, } -// signatureSchemesForCertificate returns the list of supported SignatureSchemes -// for a given certificate, based on the public key and the protocol version, -// and optionally filtered by its explicit SupportedSignatureAlgorithms. -func signatureSchemesForCertificate(version uint16, cert *Certificate) []SignatureScheme { - priv, ok := cert.PrivateKey.(crypto.Signer) - if !ok { - return nil - } - - var sigAlgs []SignatureScheme - switch pub := priv.Public().(type) { +func signatureSchemesForPublicKey(version uint16, pub crypto.PublicKey) []SignatureScheme { + switch pub := pub.(type) { case *ecdsa.PublicKey: - if version != VersionTLS13 { + if version < VersionTLS13 { // In TLS 1.2 and earlier, ECDSA algorithms are not // constrained to a single curve. - sigAlgs = []SignatureScheme{ + return []SignatureScheme{ ECDSAWithP256AndSHA256, ECDSAWithP384AndSHA384, ECDSAWithP521AndSHA512, ECDSAWithSHA1, } - break } switch pub.Curve { case elliptic.P256(): - sigAlgs = []SignatureScheme{ECDSAWithP256AndSHA256} + return []SignatureScheme{ECDSAWithP256AndSHA256} case elliptic.P384(): - sigAlgs = []SignatureScheme{ECDSAWithP384AndSHA384} + return []SignatureScheme{ECDSAWithP384AndSHA384} case elliptic.P521(): - sigAlgs = []SignatureScheme{ECDSAWithP521AndSHA512} + return []SignatureScheme{ECDSAWithP521AndSHA512} default: return nil } case *rsa.PublicKey: size := pub.Size() - sigAlgs = make([]SignatureScheme, 0, len(rsaSignatureSchemes)) + sigAlgs := make([]SignatureScheme, 0, len(rsaSignatureSchemes)) for _, candidate := range rsaSignatureSchemes { - if size >= candidate.minModulusBytes && version <= candidate.maxVersion { + if size >= candidate.minModulusBytes { sigAlgs = append(sigAlgs, candidate.scheme) } } + return sigAlgs case ed25519.PublicKey: - sigAlgs = []SignatureScheme{Ed25519} + return []SignatureScheme{Ed25519} default: return nil } - - if cert.SupportedSignatureAlgorithms != nil { - sigAlgs = slices.DeleteFunc(sigAlgs, func(sigAlg SignatureScheme) bool { - return !isSupportedSignatureAlgorithm(sigAlg, cert.SupportedSignatureAlgorithms) - }) - } - - // Filter out any unsupported signature algorithms, for example due to - // FIPS 140-3 policy, tlssha1=0, or any downstream changes to defaults.go. - supportedAlgs := supportedSignatureAlgorithms(version) - sigAlgs = slices.DeleteFunc(sigAlgs, func(sigAlg SignatureScheme) bool { - return !isSupportedSignatureAlgorithm(sigAlg, supportedAlgs) - }) - - return sigAlgs } // selectSignatureScheme picks a SignatureScheme from the peer's preference list // that works with the selected certificate. It's only called for protocol // versions that support signature algorithms, so TLS 1.2 and 1.3. func selectSignatureScheme(vers uint16, c *Certificate, peerAlgs []SignatureScheme) (SignatureScheme, error) { - supportedAlgs := signatureSchemesForCertificate(vers, c) + priv, ok := c.PrivateKey.(crypto.Signer) + if !ok { + return 0, unsupportedCertificateError(c) + } + supportedAlgs := signatureSchemesForPublicKey(vers, priv.Public()) + if c.SupportedSignatureAlgorithms != nil { + supportedAlgs = slices.DeleteFunc(supportedAlgs, func(sigAlg SignatureScheme) bool { + return !isSupportedSignatureAlgorithm(sigAlg, c.SupportedSignatureAlgorithms) + }) + } + // Filter out any unsupported signature algorithms, for example due to + // FIPS 140-3 policy, tlssha1=0, or protocol version. + supportedAlgs = slices.DeleteFunc(supportedAlgs, func(sigAlg SignatureScheme) bool { + return isDisabledSignatureAlgorithm(vers, sigAlg, false) + }) if len(supportedAlgs) == 0 { return 0, unsupportedCertificateError(c) } diff --git a/cipher_suites.go b/cipher_suites.go index e74afb2..3cb9fa3 100644 --- a/cipher_suites.go +++ b/cipher_suites.go @@ -146,8 +146,8 @@ type cipherSuite struct { } var cipherSuites = []*cipherSuite{ // TODO: replace with a map, since the order doesn't matter. - {TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, 32, 0, 12, ecdheRSAKA, suiteECDHE | suiteTLS12, nil, nil, aeadChaCha20Poly1305}, - {TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, 32, 0, 12, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, nil, nil, aeadChaCha20Poly1305}, + {TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, 32, 0, 12, ecdheRSAKA, suiteECDHE | suiteTLS12, nil, nil, aeadChaCha20Poly1305}, + {TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, 32, 0, 12, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, nil, nil, aeadChaCha20Poly1305}, {TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, ecdheRSAKA, suiteECDHE | suiteTLS12, nil, nil, aeadAESGCM}, {TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, ecdheECDSAKA, suiteECDHE | suiteECSign | suiteTLS12, nil, nil, aeadAESGCM}, {TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, ecdheRSAKA, suiteECDHE | suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM}, @@ -281,7 +281,7 @@ var cipherSuitesPreferenceOrder = []uint16{ // AEADs w/ ECDHE TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, // CBC w/ ECDHE TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, @@ -310,7 +310,7 @@ var cipherSuitesPreferenceOrder = []uint16{ var cipherSuitesPreferenceOrderNoAES = []uint16{ // ChaCha20Poly1305 - TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, // AES-GCM w/ ECDHE TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, diff --git a/common.go b/common.go index 28f4af0..94dfd66 100644 --- a/common.go +++ b/common.go @@ -20,6 +20,7 @@ import ( "fmt" "io" "net" + "runtime" "slices" "strings" "sync" @@ -145,19 +146,31 @@ const ( type CurveID uint16 const ( - CurveP256 CurveID = 23 - CurveP384 CurveID = 24 - CurveP521 CurveID = 25 - X25519 CurveID = 29 - X25519MLKEM768 CurveID = 4588 + CurveP256 CurveID = 23 + CurveP384 CurveID = 24 + CurveP521 CurveID = 25 + X25519 CurveID = 29 + X25519MLKEM768 CurveID = 4588 + SecP256r1MLKEM768 CurveID = 4587 + SecP384r1MLKEM1024 CurveID = 4589 ) func isTLS13OnlyKeyExchange(curve CurveID) bool { - return curve == X25519MLKEM768 + switch curve { + case X25519MLKEM768, SecP256r1MLKEM768, SecP384r1MLKEM1024: + return true + default: + return false + } } func isPQKeyExchange(curve CurveID) bool { - return curve == X25519MLKEM768 + switch curve { + case X25519MLKEM768, SecP256r1MLKEM768, SecP384r1MLKEM1024: + return true + default: + return false + } } // TLS 1.3 Key Share. See RFC 8446, Section 4.2.8. @@ -304,11 +317,16 @@ type ConnectionState struct { // client side. ECHAccepted bool + // HelloRetryRequest indicates whether we sent a HelloRetryRequest if we + // are a server, or if we received a HelloRetryRequest if we are a client. + HelloRetryRequest bool + // ekm is a closure exposed via ExportKeyingMaterial. ekm func(label string, context []byte, length int) ([]byte, error) - // testingOnlyDidHRR is true if a HelloRetryRequest was sent/received. - testingOnlyDidHRR bool + // testingOnlyPeerSignatureAlgorithm is the signature algorithm used by the + // peer to sign the handshake. It is not set for resumed connections. + testingOnlyPeerSignatureAlgorithm SignatureScheme } // ExportKeyingMaterial returns length bytes of exported key material in a new @@ -465,6 +483,10 @@ type ClientHelloInfo struct { // connection to fail. Conn net.Conn + // HelloRetryRequest indicates whether the ClientHello was sent in response + // to a HelloRetryRequest message. + HelloRetryRequest bool + // config is embedded by the GetCertificate or GetConfigForClient caller, // for use with SupportsCertificate. config *Config @@ -636,10 +658,13 @@ type Config struct { // If GetConfigForClient is nil, the Config passed to Server() will be // used for all connections. // - // If SessionTicketKey was explicitly set on the returned Config, or if - // SetSessionTicketKeys was called on the returned Config, those keys will + // If SessionTicketKey is explicitly set on the returned Config, or if + // SetSessionTicketKeys is called on the returned Config, those keys will // be used. Otherwise, the original Config keys will be used (and possibly - // rotated if they are automatically managed). + // rotated if they are automatically managed). WARNING: this allows session + // resumption of connections originally established with the parent (or a + // sibling) Config, which may bypass the [Config.VerifyPeerCertificate] + // value of the returned Config. GetConfigForClient func(*ClientHelloInfo) (*Config, error) // VerifyPeerCertificate, if not nil, is called after normal @@ -657,8 +682,10 @@ type Config struct { // rawCerts may be empty on the server if ClientAuth is RequestClientCert or // VerifyClientCertIfGiven. // - // This callback is not invoked on resumed connections, as certificates are - // not re-verified on resumption. + // This callback is not invoked on resumed connections. WARNING: this + // includes connections resumed across Configs returned by [Config.Clone] or + // [Config.GetConfigForClient] and their parents. If that is not intended, + // use [Config.VerifyConnection] instead, or set [Config.SessionTicketsDisabled]. // // verifiedChains and its contents should not be modified. VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error @@ -803,6 +830,11 @@ type Config struct { // From Go 1.24, the default includes the [X25519MLKEM768] hybrid // post-quantum key exchange. To disable it, set CurvePreferences explicitly // or use the GODEBUG=tlsmlkem=0 environment variable. + // + // From Go 1.26, the default includes the [SecP256r1MLKEM768] and + // [SecP384r1MLKEM1024] hybrid post-quantum key exchanges, too. To disable + // them, set CurvePreferences explicitly or use either the + // GODEBUG=tlsmlkem=0 or the GODEBUG=tlssecpmlkem=0 environment variable. CurvePreferences []CurveID // DynamicRecordSizingDisabled disables adaptive sizing of TLS records. @@ -818,7 +850,7 @@ type Config struct { // KeyLogWriter optionally specifies a destination for TLS master secrets // in NSS key log format that can be used to allow external programs // such as Wireshark to decrypt TLS connections. - // See https://developer.mozilla.org/en-US/docs/Mozilla/Projects/NSS/Key_Log_Format. + // See https://datatracker.ietf.org/doc/draft-ietf-tls-keylogfile/. // Use of KeyLogWriter compromises security and should only be // used for debugging. KeyLogWriter io.Writer @@ -910,13 +942,29 @@ type Config struct { // with a specific ECH config known to a client. type EncryptedClientHelloKey struct { // Config should be a marshalled ECHConfig associated with PrivateKey. This - // must match the config provided to clients byte-for-byte. The config - // should only specify the DHKEM(X25519, HKDF-SHA256) KEM ID (0x0020), the - // HKDF-SHA256 KDF ID (0x0001), and a subset of the following AEAD IDs: - // AES-128-GCM (0x0001), AES-256-GCM (0x0002), ChaCha20Poly1305 (0x0003). + // must match the config provided to clients byte-for-byte. The config must + // use as KEM one of + // + // - DHKEM(P-256, HKDF-SHA256) (0x0010) + // - DHKEM(P-384, HKDF-SHA384) (0x0011) + // - DHKEM(P-521, HKDF-SHA512) (0x0012) + // - DHKEM(X25519, HKDF-SHA256) (0x0020) + // + // and as KDF one of + // + // - HKDF-SHA256 (0x0001) + // - HKDF-SHA384 (0x0002) + // - HKDF-SHA512 (0x0003) + // + // and as AEAD one of + // + // - AES-128-GCM (0x0001) + // - AES-256-GCM (0x0002) + // - ChaCha20Poly1305 (0x0003) + // Config []byte - // PrivateKey should be a marshalled private key. Currently, we expect - // this to be the output of [ecdh.PrivateKey.Bytes]. + // PrivateKey should be a marshalled private key, in the format expected by + // HPKE's DeserializePrivateKey (see RFC 9180), for the KEM used in Config. PrivateKey []byte // SendAsRetry indicates if Config should be sent as part of the list of // retry configs when ECH is requested by the client but rejected by the @@ -961,8 +1009,15 @@ func (c *Config) ticketKeyFromBytes(b [32]byte) (key ticketKey) { // ticket, and the lifetime we set for all tickets we send. const maxSessionTicketLifetime = 7 * 24 * time.Hour -// Clone returns a shallow clone of c or nil if c is nil. It is safe to clone a [Config] that is -// being used concurrently by a TLS client or server. +// Clone returns a shallow clone of c or nil if c is nil. It is safe to clone a +// [Config] that is being used concurrently by a TLS client or server. +// +// The returned Config can share session ticket keys with the original Config, +// which means connections could be resumed across the two Configs. WARNING: +// [Config.VerifyPeerCertificate] does not get called on resumed connections, +// including connections that were originally established on the parent Config. +// If that is not intended, use [Config.VerifyConnection] instead, or set +// [Config.SessionTicketsDisabled]. func (c *Config) Clone() *Config { if c == nil { return nil @@ -1589,9 +1644,14 @@ var writerMutex sync.Mutex type Certificate struct { Certificate [][]byte // PrivateKey contains the private key corresponding to the public key in - // Leaf. This must implement crypto.Signer with an RSA, ECDSA or Ed25519 PublicKey. + // Leaf. This must implement [crypto.Signer] with an RSA, ECDSA or Ed25519 + // PublicKey. + // // For a server up to TLS 1.2, it can also implement crypto.Decrypter with // an RSA PublicKey. + // + // If it implements [crypto.MessageSigner], SignMessage will be used instead + // of Sign for TLS 1.2 and later. PrivateKey crypto.PrivateKey // SupportedSignatureAlgorithms is an optional list restricting what // signature algorithms the PrivateKey can be used for. @@ -1718,35 +1778,62 @@ func unexpectedMessageError(wanted, got any) error { return fmt.Errorf("tls: received unexpected handshake message of type %T when waiting for %T", got, wanted) } +var testingOnlySupportedSignatureAlgorithms []SignatureScheme + // supportedSignatureAlgorithms returns the supported signature algorithms for // the given minimum TLS version, to advertise in ClientHello and // CertificateRequest messages. func supportedSignatureAlgorithms(minVers uint16) []SignatureScheme { sigAlgs := defaultSupportedSignatureAlgorithms() - if fips140tls.Required() { - sigAlgs = slices.DeleteFunc(sigAlgs, func(s SignatureScheme) bool { - return !slices.Contains(allowedSignatureAlgorithmsFIPS, s) - }) + if testingOnlySupportedSignatureAlgorithms != nil { + sigAlgs = slices.Clone(testingOnlySupportedSignatureAlgorithms) } - if minVers > VersionTLS12 { - sigAlgs = slices.DeleteFunc(sigAlgs, func(s SignatureScheme) bool { - sigType, sigHash, _ := typeAndHashFromSignatureScheme(s) - return sigType == signaturePKCS1v15 || sigHash == crypto.SHA1 - }) + return slices.DeleteFunc(sigAlgs, func(s SignatureScheme) bool { + return isDisabledSignatureAlgorithm(minVers, s, false) + }) +} + +//var tlssha1 = godebug.New("tlssha1") + +func isDisabledSignatureAlgorithm(version uint16, s SignatureScheme, isCert bool) bool { + if fips140tls.Required() && !slices.Contains(allowedSignatureAlgorithmsFIPS, s) { + return true + } + + // For the _cert extension we include all algorithms, including SHA-1 and + // PKCS#1 v1.5, because it's more likely that something on our side will be + // willing to accept a *-with-SHA1 certificate (e.g. with a custom + // VerifyConnection or by a direct match with the CertPool), than that the + // peer would have a better certificate but is just choosing not to send it. + // crypto/x509 will refuse to verify important SHA-1 signatures anyway. + if isCert { + return false + } + + // TLS 1.3 removed support for PKCS#1 v1.5 and SHA-1 signatures, + // and Go 1.25 removed support for SHA-1 signatures in TLS 1.2. + if version > VersionTLS12 { + sigType, sigHash, _ := typeAndHashFromSignatureScheme(s) + if sigType == signaturePKCS1v15 || sigHash == crypto.SHA1 { + return true + } + } else { //if tlssha1.Value() != "1" { + _, sigHash, _ := typeAndHashFromSignatureScheme(s) + if sigHash == crypto.SHA1 { + return true + } } - return sigAlgs + + return false } // supportedSignatureAlgorithmsCert returns the supported algorithms for // signatures in certificates. func supportedSignatureAlgorithmsCert() []SignatureScheme { - sigAlgs := defaultSupportedSignatureAlgorithmsCert() - if fips140tls.Required() { - sigAlgs = slices.DeleteFunc(sigAlgs, func(s SignatureScheme) bool { - return !slices.Contains(allowedSignatureAlgorithmsFIPS, s) - }) - } - return sigAlgs + sigAlgs := defaultSupportedSignatureAlgorithms() + return slices.DeleteFunc(sigAlgs, func(s SignatureScheme) bool { + return isDisabledSignatureAlgorithm(0, s, true) + }) } func isSupportedSignatureAlgorithm(sigAlg SignatureScheme, supportedSignatureAlgorithms []SignatureScheme) bool { @@ -1806,3 +1893,43 @@ func fipsAllowChain(chain []*x509.Certificate) bool { return true } + +// anyValidVerifiedChain reports if at least one of the chains in verifiedChains +// is valid, as indicated by none of the certificates being expired and the root +// being in opts.Roots (or in the system root pool if opts.Roots is nil). If +// verifiedChains is empty, it returns false. +func anyValidVerifiedChain(verifiedChains [][]*x509.Certificate, opts x509.VerifyOptions) bool { + for _, chain := range verifiedChains { + if len(chain) == 0 { + continue + } + if slices.ContainsFunc(chain, func(cert *x509.Certificate) bool { + return opts.CurrentTime.Before(cert.NotBefore) || opts.CurrentTime.After(cert.NotAfter) + }) { + continue + } + // Since we already validated the chain, we only care that it is rooted + // in a CA in opts.Roots. On platforms where we control chain validation + // (e.g. not Windows or macOS) this is a simple lookup in the CertPool + // internal hash map, which we can simulate by running Verify on the + // root. On other platforms, we have to do full verification again, + // because EKU handling might differ. We will want to replace this with + // CertPool.Contains if/once that is available. See go.dev/issue/77376. + if runtime.GOOS == "windows" || runtime.GOOS == "darwin" || runtime.GOOS == "ios" { + opts.Intermediates = x509.NewCertPool() + for _, cert := range chain[1:max(1, len(chain)-1)] { + opts.Intermediates.AddCert(cert) + } + leaf := chain[0] + if _, err := leaf.Verify(opts); err == nil { + return true + } + } else { + root := chain[len(chain)-1] + if _, err := root.Verify(opts); err == nil { + return true + } + } + } + return false +} diff --git a/common_string.go b/common_string.go index b644d35..6525e8e 100644 --- a/common_string.go +++ b/common_string.go @@ -72,16 +72,19 @@ func _() { _ = x[CurveP521-25] _ = x[X25519-29] _ = x[X25519MLKEM768-4588] + _ = x[SecP256r1MLKEM768-4587] + _ = x[SecP384r1MLKEM1024-4589] } const ( _CurveID_name_0 = "CurveP256CurveP384CurveP521" _CurveID_name_1 = "X25519" - _CurveID_name_2 = "X25519MLKEM768" + _CurveID_name_2 = "SecP256r1MLKEM768X25519MLKEM768SecP384r1MLKEM1024" ) var ( _CurveID_index_0 = [...]uint8{0, 9, 18, 27} + _CurveID_index_2 = [...]uint8{0, 17, 31, 49} ) func (i CurveID) String() string { @@ -91,8 +94,9 @@ func (i CurveID) String() string { return _CurveID_name_0[_CurveID_index_0[i]:_CurveID_index_0[i+1]] case i == 29: return _CurveID_name_1 - case i == 4588: - return _CurveID_name_2 + case 4587 <= i && i <= 4589: + i -= 4587 + return _CurveID_name_2[_CurveID_index_2[i]:_CurveID_index_2[i+1]] default: return "CurveID(" + strconv.FormatInt(int64(i), 10) + ")" } diff --git a/conn.go b/conn.go index d99b2e6..3bcca8b 100644 --- a/conn.go +++ b/conn.go @@ -56,6 +56,7 @@ type Conn struct { didHRR bool // whether a HelloRetryRequest was sent/received cipherSuite uint16 curveID CurveID + peerSigAlg SignatureScheme ocspResponse []byte // stapled OCSP response scts [][]byte // signed certificate timestamps from server peerCertificates []*x509.Certificate @@ -227,20 +228,19 @@ func (hc *halfConn) changeCipherSpec() error { hc.mac = hc.nextMac hc.nextCipher = nil hc.nextMac = nil - for i := range hc.seq { - hc.seq[i] = 0 - } + clear(hc.seq[:]) return nil } +// setTrafficSecret sets the traffic secret for the given encryption level. setTrafficSecret +// should not be called directly, but rather through the Conn setWriteTrafficSecret and +// setReadTrafficSecret wrapper methods. func (hc *halfConn) setTrafficSecret(suite *cipherSuiteTLS13, level QUICEncryptionLevel, secret []byte) { hc.trafficSecret = secret hc.level = level key, iv := suite.trafficKey(secret) hc.cipher = suite.aead(key, iv) - for i := range hc.seq { - hc.seq[i] = 0 - } + clear(hc.seq[:]) } // incSeq increments the sequence number. @@ -838,29 +838,6 @@ func (c *Conn) retryReadRecord(expectChangeCipherSpec bool) error { return c.readRecordOrCCS(expectChangeCipherSpec) } -// atLeastReader reads from R, stopping with EOF once at least N bytes have been -// read. It is different from an io.LimitedReader in that it doesn't cut short -// the last Read call, and in that it considers an early EOF an error. -type atLeastReader struct { - R io.Reader - N int64 -} - -func (r *atLeastReader) Read(p []byte) (int, error) { - if r.N <= 0 { - return 0, io.EOF - } - n, err := r.R.Read(p) - r.N -= int64(n) // won't underflow unless len(p) >= n > 9223372036854775809 - if r.N > 0 && err == io.EOF { - return n, io.ErrUnexpectedEOF - } - if r.N <= 0 && err == nil { - return n, io.EOF - } - return n, err -} - // readFromUntil reads from r into c.rawInput until c.rawInput contains // at least n bytes or else returns an error. func (c *Conn) readFromUntil(r io.Reader, n int) error { @@ -871,9 +848,31 @@ func (c *Conn) readFromUntil(r io.Reader, n int) error { // There might be extra input waiting on the wire. Make a best effort // attempt to fetch it so that it can be used in (*Conn).Read to // "predict" closeNotify alerts. + // TODO(dmo): we use bytes.MinRead here because we used the buffer + // ReadFrom mechanism to avoid allocations, but we've hoisted this + // loop for performance. We really should use our own heuristic here + // for how much to read ahead. c.rawInput.Grow(needs + bytes.MinRead) - _, err := c.rawInput.ReadFrom(&atLeastReader{r, int64(needs)}) - return err + for { + buf := c.rawInput.AvailableBuffer()[:c.rawInput.Available()] + n, err := r.Read(buf) + // This write is just to update the internal state of the + // rawInput bytes.Buffer. It cannot fail. + c.rawInput.Write(buf[:n]) + needs -= n + if needs <= 0 { + if err == io.EOF { + err = nil + } + return err + } + if err == io.EOF { + return io.ErrUnexpectedEOF + } + if err != nil { + return err + } + } } // sendAlertLocked sends a TLS alert message. @@ -1409,9 +1408,6 @@ func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error { return c.in.setErrorLocked(c.sendAlert(alertInternalError)) } - newSecret := cipherSuite.nextTrafficSecret(c.in.trafficSecret) - c.in.setTrafficSecret(cipherSuite, QUICEncryptionLevelInitial, newSecret) - if keyUpdate.updateRequested { c.out.Lock() defer c.out.Unlock() @@ -1429,7 +1425,12 @@ func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error { } newSecret := cipherSuite.nextTrafficSecret(c.out.trafficSecret) - c.out.setTrafficSecret(cipherSuite, QUICEncryptionLevelInitial, newSecret) + c.setWriteTrafficSecret(cipherSuite, QUICEncryptionLevelInitial, newSecret) + } + + newSecret := cipherSuite.nextTrafficSecret(c.in.trafficSecret) + if err := c.setReadTrafficSecret(cipherSuite, QUICEncryptionLevelInitial, newSecret, keyUpdate.updateRequested); err != nil { + return err } return nil @@ -1590,37 +1591,23 @@ func (c *Conn) handshakeContext(ctx context.Context) (ret error) { } handshakeCtx, cancel := context.WithCancel(ctx) - // Note: defer this before starting the "interrupter" goroutine + // Note: defer this before calling context.AfterFunc // so that we can tell the difference between the input being canceled and // this cancellation. In the former case, we need to close the connection. defer cancel() if c.quic != nil { - c.quic.cancelc = handshakeCtx.Done() + c.quic.ctx = handshakeCtx c.quic.cancel = cancel } else if ctx.Done() != nil { - // Start the "interrupter" goroutine, if this context might be canceled. - // (The background context cannot). - // - // The interrupter goroutine waits for the input context to be done and - // closes the connection if this happens before the function returns. - done := make(chan struct{}) - interruptRes := make(chan error, 1) + // Close the connection if ctx is canceled before the function returns. + stop := context.AfterFunc(ctx, func() { + _ = c.conn.Close() + }) defer func() { - close(done) - if ctxErr := <-interruptRes; ctxErr != nil { + if !stop() { // Return context error to user. - ret = ctxErr - } - }() - go func() { - select { - case <-handshakeCtx.Done(): - // Close the connection, discarding the error - _ = c.conn.Close() - interruptRes <- handshakeCtx.Err() - case <-done: - interruptRes <- nil + ret = ctx.Err() } }() } @@ -1660,11 +1647,13 @@ func (c *Conn) handshakeContext(ctx context.Context) (ret error) { // Provide the 1-RTT read secret now that the handshake is complete. // The QUIC layer MUST NOT decrypt 1-RTT packets prior to completing // the handshake (RFC 9001, Section 5.7). - c.quicSetReadSecret(QUICEncryptionLevelApplication, c.cipherSuite, c.in.trafficSecret) + if err := c.quicSetReadSecret(QUICEncryptionLevelApplication, c.cipherSuite, c.in.trafficSecret); err != nil { + return err + } } else { - var a alert c.out.Lock() - if !errors.As(c.out.err, &a) { + a, ok := errors.AsType[alert](c.out.err) + if !ok { a = alertInternalError } c.out.Unlock() @@ -1694,7 +1683,8 @@ func (c *Conn) connectionStateLocked() ConnectionState { state.Version = c.vers state.NegotiatedProtocol = c.clientProtocol state.DidResume = c.didResume - state.testingOnlyDidHRR = c.didHRR + state.HelloRetryRequest = c.didHRR + state.testingOnlyPeerSignatureAlgorithm = c.peerSigAlg state.CurveID = c.curveID state.NegotiatedProtocolIsMutual = true state.ServerName = c.serverName @@ -1753,3 +1743,29 @@ func (c *Conn) VerifyHostname(host string) error { } return c.peerCertificates[0].VerifyHostname(host) } + +// setReadTrafficSecret sets the read traffic secret for the given encryption level. If +// being called at the same time as setWriteTrafficSecret, the caller must ensure the call +// to setWriteTrafficSecret happens first so any alerts are sent at the write level. +func (c *Conn) setReadTrafficSecret(suite *cipherSuiteTLS13, level QUICEncryptionLevel, secret []byte, locked bool) error { + // Ensure that there are no buffered handshake messages before changing the + // read keys, since that can cause messages to be parsed that were encrypted + // using old keys which are no longer appropriate. + if c.hand.Len() != 0 { + if locked { + c.sendAlertLocked(alertUnexpectedMessage) + } else { + c.sendAlert(alertUnexpectedMessage) + } + return errors.New("tls: handshake buffer not empty before setting read traffic secret") + } + c.in.setTrafficSecret(suite, level, secret) + return nil +} + +// setWriteTrafficSecret sets the write traffic secret for the given encryption level. If +// being called at the same time as setReadTrafficSecret, the caller must ensure the call +// to setWriteTrafficSecret happens first so any alerts are sent at the write level. +func (c *Conn) setWriteTrafficSecret(suite *cipherSuiteTLS13, level QUICEncryptionLevel, secret []byte) { + c.out.setTrafficSecret(suite, level, secret) +} diff --git a/defaults.go b/defaults.go index 617235b..bf24450 100644 --- a/defaults.go +++ b/defaults.go @@ -13,47 +13,31 @@ import ( // them to apply local policies. //var tlsmlkem = godebug.New("tlsmlkem") +//var tlssecpmlkem = godebug.New("tlssecpmlkem") // defaultCurvePreferences is the default set of supported key exchanges, as // well as the preference order. func defaultCurvePreferences() []CurveID { - if false { - return []CurveID{X25519, CurveP256, CurveP384, CurveP521} + switch { + // // tlsmlkem=0 restores the pre-Go 1.24 default. + // case tlsmlkem.Value() == "0": + // return []CurveID{X25519, CurveP256, CurveP384, CurveP521} + // // tlssecpmlkem=0 restores the pre-Go 1.26 default. + // case tlssecpmlkem.Value() == "0": + // return []CurveID{X25519MLKEM768, X25519, CurveP256, CurveP384, CurveP521} + default: + return []CurveID{ + X25519MLKEM768, SecP256r1MLKEM768, SecP384r1MLKEM1024, + X25519, CurveP256, CurveP384, CurveP521, + } } - return []CurveID{X25519MLKEM768, X25519, CurveP256, CurveP384, CurveP521} } -//var tlssha1 = godebug.New("tlssha1") - // defaultSupportedSignatureAlgorithms returns the signature and hash algorithms that // the code advertises and supports in a TLS 1.2+ ClientHello and in a TLS 1.2+ // CertificateRequest. The two fields are merged to match with TLS 1.3. // Note that in TLS 1.2, the ECDSA algorithms are not constrained to P-256, etc. func defaultSupportedSignatureAlgorithms() []SignatureScheme { - return []SignatureScheme{ - PSSWithSHA256, - ECDSAWithP256AndSHA256, - Ed25519, - PSSWithSHA384, - PSSWithSHA512, - PKCS1WithSHA256, - PKCS1WithSHA384, - PKCS1WithSHA512, - ECDSAWithP384AndSHA384, - ECDSAWithP521AndSHA512, - } -} - -// defaultSupportedSignatureAlgorithmsCert returns the signature algorithms that -// the code advertises as supported for signatures in certificates. -// -// We include all algorithms, including SHA-1 and PKCS#1 v1.5, because it's more -// likely that something on our side will be willing to accept a *-with-SHA1 -// certificate (e.g. with a custom VerifyConnection or by a direct match with -// the CertPool), than that the peer would have a better certificate but is just -// choosing not to send it. crypto/x509 will refuse to verify important SHA-1 -// signatures anyway. -func defaultSupportedSignatureAlgorithmsCert() []SignatureScheme { return []SignatureScheme{ PSSWithSHA256, ECDSAWithP256AndSHA256, diff --git a/defaults_fips140.go b/defaults_fips140.go index be0ba95..2777a79 100644 --- a/defaults_fips140.go +++ b/defaults_fips140.go @@ -32,6 +32,8 @@ var ( } allowedCurvePreferencesFIPS = []CurveID{ X25519MLKEM768, + SecP256r1MLKEM768, + SecP384r1MLKEM1024, CurveP256, CurveP384, CurveP521, diff --git a/ech.go b/ech.go index 3d9226b..f86752e 100644 --- a/ech.go +++ b/ech.go @@ -6,28 +6,14 @@ package reality import ( "bytes" + "crypto/hpke" "errors" "fmt" - "slices" "strings" "golang.org/x/crypto/cryptobyte" - - "github.com/xtls/reality/hpke" ) -// sortedSupportedAEADs is just a sorted version of hpke.SupportedAEADS. -// We need this so that when we insert them into ECHConfigs the ordering -// is stable. -var sortedSupportedAEADs []uint16 - -func init() { - for aeadID := range hpke.SupportedAEADs { - sortedSupportedAEADs = append(sortedSupportedAEADs, aeadID) - } - slices.Sort(sortedSupportedAEADs) -} - type EchCipher struct { KDFID uint16 AEADID uint16 @@ -163,25 +149,8 @@ func parseECHConfigList(data []byte) ([]EchConfig, error) { return configs, nil } -func pickECHConfig(list []EchConfig) *EchConfig { +func pickECHConfig(list []EchConfig) (*EchConfig, hpke.PublicKey, hpke.KDF, hpke.AEAD) { for _, ec := range list { - if _, ok := hpke.SupportedKEMs[ec.KemID]; !ok { - continue - } - var validSCS bool - for _, cs := range ec.SymmetricCipherSuite { - if _, ok := hpke.SupportedAEADs[cs.AEADID]; !ok { - continue - } - if _, ok := hpke.SupportedKDFs[cs.KDFID]; !ok { - continue - } - validSCS = true - break - } - if !validSCS { - continue - } if !validDNSName(string(ec.PublicName)) { continue } @@ -197,25 +166,37 @@ func pickECHConfig(list []EchConfig) *EchConfig { if unsupportedExt { continue } - return &ec - } - return nil -} - -func pickECHCipherSuite(suites []EchCipher) (EchCipher, error) { - for _, s := range suites { - // NOTE: all of the supported AEADs and KDFs are fine, rather than - // imposing some sort of preference here, we just pick the first valid - // suite. - if _, ok := hpke.SupportedAEADs[s.AEADID]; !ok { + kem, err := hpke.NewKEM(ec.KemID) + if err != nil { continue } - if _, ok := hpke.SupportedKDFs[s.KDFID]; !ok { + pub, err := kem.NewPublicKey(ec.PublicKey) + if err != nil { + // This is an error in the config, but killing the connection feels + // excessive. continue } - return s, nil + for _, cs := range ec.SymmetricCipherSuite { + // All of the supported AEADs and KDFs are fine, rather than + // imposing some sort of preference here, we just pick the first + // valid suite. + kdf, err := hpke.NewKDF(cs.KDFID) + if err != nil { + continue + } + // 0xFFFF is an export-only AEAD that cannot seal/open, making + // it an invalid choice for encrypting ClientHelloInner. + if cs.AEADID == 0xFFFF { + continue + } + aead, err := hpke.NewAEAD(cs.AEADID) + if err != nil { + continue + } + return &ec, pub, kdf, aead + } } - return EchCipher{}, errors.New("tls: no supported symmetric ciphersuites for ECH") + return nil, nil, nil, nil } func encodeInnerClientHello(inner *clientHelloMsg, maxNameLength int) ([]byte, error) { @@ -231,7 +212,7 @@ func encodeInnerClientHello(inner *clientHelloMsg, maxNameLength int) ([]byte, e } else { paddingLen = maxNameLength + 9 } - paddingLen = 31 - ((len(h) + paddingLen - 1) % 32) + paddingLen += 31 - ((len(h) + paddingLen - 1) % 32) return append(h, make([]byte, paddingLen)...), nil } @@ -569,16 +550,6 @@ func parseECHExt(ext []byte) (echType echExtType, cs EchCipher, configID uint8, return echType, cs, configID, bytes.Clone(encap), bytes.Clone(payload), nil } -func marshalEncryptedClientHelloConfigList(configs []EncryptedClientHelloKey) ([]byte, error) { - builder := cryptobyte.NewBuilder(nil) - builder.AddUint16LengthPrefixed(func(builder *cryptobyte.Builder) { - for _, c := range configs { - builder.AddBytes(c.Config) - } - }) - return builder.Bytes() -} - func (c *Conn) processECHClientHello(outer *clientHelloMsg, echKeys []EncryptedClientHelloKey) (*clientHelloMsg, *echServerContext, error) { echType, echCiphersuite, configID, encap, payload, err := parseECHExt(outer.encryptedClientHello) if err != nil { @@ -601,20 +572,35 @@ func (c *Conn) processECHClientHello(outer *clientHelloMsg, echKeys []EncryptedC for _, echKey := range echKeys { skip, config, err := parseECHConfig(echKey.Config) - if err != nil || skip { + if err != nil { c.sendAlert(alertInternalError) - return nil, nil, fmt.Errorf("tls: invalid EncryptedClientHelloKeys Config: %s", err) + return nil, nil, fmt.Errorf("tls: invalid EncryptedClientHelloKey Config: %s", err) } if skip { continue } - echPriv, err := hpke.ParseHPKEPrivateKey(config.KemID, echKey.PrivateKey) + kem, err := hpke.NewKEM(config.KemID) + if err != nil { + c.sendAlert(alertInternalError) + return nil, nil, fmt.Errorf("tls: invalid EncryptedClientHelloKey Config KEM: %s", err) + } + echPriv, err := kem.NewPrivateKey(echKey.PrivateKey) + if err != nil { + c.sendAlert(alertInternalError) + return nil, nil, fmt.Errorf("tls: invalid EncryptedClientHelloKey PrivateKey: %s", err) + } + kdf, err := hpke.NewKDF(echCiphersuite.KDFID) + if err != nil { + c.sendAlert(alertInternalError) + return nil, nil, fmt.Errorf("tls: invalid EncryptedClientHelloKey Config KDF: %s", err) + } + aead, err := hpke.NewAEAD(echCiphersuite.AEADID) if err != nil { c.sendAlert(alertInternalError) - return nil, nil, fmt.Errorf("tls: invalid EncryptedClientHelloKeys PrivateKey: %s", err) + return nil, nil, fmt.Errorf("tls: invalid EncryptedClientHelloKey Config AEAD: %s", err) } info := append([]byte("tls ech\x00"), echKey.Config...) - hpkeContext, err := hpke.SetupRecipient(hpke.DHKEM_X25519_HKDF_SHA256, echCiphersuite.KDFID, echCiphersuite.AEADID, echPriv, info, encap) + hpkeContext, err := hpke.NewRecipient(encap, echPriv, kdf, aead, info) if err != nil { // attempt next trial decryption continue diff --git a/go.mod b/go.mod index 22f8f90..5268cb5 100644 --- a/go.mod +++ b/go.mod @@ -1,18 +1,18 @@ module github.com/xtls/reality -go 1.24.0 +go 1.26 require ( github.com/cloudflare/circl v1.6.3 github.com/juju/ratelimit v1.0.2 - github.com/pires/go-proxyproto v0.11.0 + github.com/pires/go-proxyproto v0.12.0 github.com/refraction-networking/utls v1.8.2 - golang.org/x/crypto v0.48.0 - golang.org/x/sys v0.41.0 + golang.org/x/crypto v0.51.0 + golang.org/x/sys v0.44.0 ) require ( - github.com/andybalholm/brotli v1.0.6 // indirect - github.com/klauspost/compress v1.17.4 // indirect + github.com/andybalholm/brotli v1.2.1 // indirect + github.com/klauspost/compress v1.18.6 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect ) diff --git a/go.sum b/go.sum index 72fddb4..cf5d012 100644 --- a/go.sum +++ b/go.sum @@ -1,23 +1,25 @@ -github.com/andybalholm/brotli v1.0.6 h1:Yf9fFpf49Zrxb9NlQaluyE92/+X7UVHlhMNJN2sxfOI= -github.com/andybalholm/brotli v1.0.6/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= +github.com/andybalholm/brotli v1.2.1 h1:R+f5xP285VArJDRgowrfb9DqL18yVK0gKAW/F+eTWro= +github.com/andybalholm/brotli v1.2.1/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/cloudflare/circl v1.6.3 h1:9GPOhQGF9MCYUeXyMYlqTR6a5gTrgR/fBLXvUgtVcg8= github.com/cloudflare/circl v1.6.3/go.mod h1:2eXP6Qfat4O/Yhh8BznvKnJ+uzEoTQ6jVKJRn81BiS4= github.com/juju/ratelimit v1.0.2 h1:sRxmtRiajbvrcLQT7S+JbqU0ntsb9W2yhSdNN8tWfaI= github.com/juju/ratelimit v1.0.2/go.mod h1:qapgC/Gy+xNh9UxzV13HGGl/6UXNN+ct+vwSgWNm/qk= -github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4= -github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= +github.com/klauspost/compress v1.18.6 h1:2jupLlAwFm95+YDR+NwD2MEfFO9d4z4Prjl1XXDjuao= +github.com/klauspost/compress v1.18.6/go.mod h1:cwPg85FWrGar70rWktvGQj8/hthj3wpl0PGDogxkrSQ= github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/pires/go-proxyproto v0.11.0 h1:gUQpS85X/VJMdUsYyEgyn59uLJvGqPhJV5YvG68wXH4= -github.com/pires/go-proxyproto v0.11.0/go.mod h1:ZKAAyp3cgy5Y5Mo4n9AlScrkCZwUy0g3Jf+slqQVcuU= +github.com/pires/go-proxyproto v0.12.0 h1:TTCxD66dU898tahivkqc3hoceZp7P44FnorWyo9d5vM= +github.com/pires/go-proxyproto v0.12.0/go.mod h1:qUvfqUMEoX7T8g0q7TQLDnhMjdTrxnG0hvpMn+7ePNI= github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEvV+S9iJ2IdQo= github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM= -golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= -golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= -golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= -golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= +github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +golang.org/x/crypto v0.51.0 h1:IBPXwPfKxY7cWQZ38ZCIRPI50YLeevDLlLnyC5wRGTI= +golang.org/x/crypto v0.51.0/go.mod h1:8AdwkbraGNABw2kOX6YFPs3WM22XqI4EXEd8g+x7Oc8= +golang.org/x/sys v0.44.0 h1:ildZl3J4uzeKP07r2F++Op7E9B29JRUy+a27EibtBTQ= +golang.org/x/sys v0.44.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= diff --git a/handshake_client.go b/handshake_client.go index 029a167..c05b646 100644 --- a/handshake_client.go +++ b/handshake_client.go @@ -10,11 +10,10 @@ import ( "crypto" "crypto/ecdsa" "crypto/ed25519" - "crypto/mlkem" + "crypto/hpke" "crypto/rsa" "crypto/subtle" "crypto/x509" - "encoding/binary" "errors" "fmt" "hash" @@ -25,7 +24,6 @@ import ( "time" "github.com/xtls/reality/fips140tls" - "github.com/xtls/reality/hpke" "github.com/xtls/reality/tls13" ) @@ -41,8 +39,6 @@ type clientHandshakeState struct { ticket []byte // a fresh ticket received during this handshake } -var testingOnlyForceClientHelloSignatureAlgorithms []SignatureScheme - func (c *Conn) makeClientHello() (*clientHelloMsg, *keySharePrivateKeys, *echClientContext, error) { config := c.config if len(config.ServerName) == 0 && !config.InsecureSkipVerify { @@ -126,9 +122,6 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, *keySharePrivateKeys, *echCli hello.supportedSignatureAlgorithms = supportedSignatureAlgorithms(minVersion) hello.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithmsCert() } - if testingOnlyForceClientHelloSignatureAlgorithms != nil { - hello.supportedSignatureAlgorithms = testingOnlyForceClientHelloSignatureAlgorithms - } var keyShareKeys *keySharePrivateKeys if maxVersion >= VersionTLS13 { @@ -148,43 +141,21 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, *keySharePrivateKeys, *echCli if len(hello.supportedCurves) == 0 { return nil, nil, nil, errors.New("tls: no supported elliptic curves for ECDHE") } + // Since the order is fixed, the first one is always the one to send a + // key share for. All the PQ hybrids sort first, and produce a fallback + // ECDH share. curveID := hello.supportedCurves[0] - keyShareKeys = &keySharePrivateKeys{curveID: curveID} - // Note that if X25519MLKEM768 is supported, it will be first because - // the preference order is fixed. - if curveID == X25519MLKEM768 { - keyShareKeys.ecdhe, err = generateECDHEKey(config.rand(), X25519) - if err != nil { - return nil, nil, nil, err - } - seed := make([]byte, mlkem.SeedSize) - if _, err := io.ReadFull(config.rand(), seed); err != nil { - return nil, nil, nil, err - } - keyShareKeys.mlkem, err = mlkem.NewDecapsulationKey768(seed) - if err != nil { - return nil, nil, nil, err - } - mlkemEncapsulationKey := keyShareKeys.mlkem.EncapsulationKey().Bytes() - x25519EphemeralKey := keyShareKeys.ecdhe.PublicKey().Bytes() - hello.keyShares = []keyShare{ - {group: X25519MLKEM768, data: append(mlkemEncapsulationKey, x25519EphemeralKey...)}, - } - // If both X25519MLKEM768 and X25519 are supported, we send both key - // shares (as a fallback) and we reuse the same X25519 ephemeral - // key, as allowed by draft-ietf-tls-hybrid-design-09, Section 3.2. - if slices.Contains(hello.supportedCurves, X25519) { - hello.keyShares = append(hello.keyShares, keyShare{group: X25519, data: x25519EphemeralKey}) - } - } else { - if _, ok := curveForCurveID(curveID); !ok { - return nil, nil, nil, errors.New("tls: CurvePreferences includes unsupported curve") - } - keyShareKeys.ecdhe, err = generateECDHEKey(config.rand(), curveID) - if err != nil { - return nil, nil, nil, err - } - hello.keyShares = []keyShare{{group: curveID, data: keyShareKeys.ecdhe.PublicKey().Bytes()}} + ke, err := keyExchangeForCurveID(curveID) + if err != nil { + return nil, nil, nil, errors.New("tls: CurvePreferences includes unsupported curve") + } + keyShareKeys, hello.keyShares, err = ke.keyShares(config.rand()) + if err != nil { + return nil, nil, nil, err + } + // Only send the fallback ECDH share if the corresponding CurveID is enabled. + if len(hello.keyShares) == 2 && !slices.Contains(hello.supportedCurves, hello.keyShares[1].group) { + hello.keyShares = hello.keyShares[:1] } } @@ -211,11 +182,11 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, *keySharePrivateKeys, *echCli if err != nil { return nil, nil, nil, err } - echConfig := pickECHConfig(echConfigs) + echConfig, echPK, kdf, aead := pickECHConfig(echConfigs) if echConfig == nil { return nil, nil, nil, errors.New("tls: EncryptedClientHelloConfigList contains no valid configs") } - ech = &echClientContext{config: echConfig} + ech = &echClientContext{config: echConfig, kdfID: kdf.ID(), aeadID: aead.ID()} hello.encryptedClientHello = []byte{1} // indicate inner hello // We need to explicitly set these 1.2 fields to nil, as we do not // marshal them when encoding the inner hello, otherwise transcripts @@ -225,17 +196,8 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, *keySharePrivateKeys, *echCli hello.secureRenegotiationSupported = false hello.extendedMasterSecret = false - echPK, err := hpke.ParseHPKEPublicKey(ech.config.KemID, ech.config.PublicKey) - if err != nil { - return nil, nil, nil, err - } - suite, err := pickECHCipherSuite(ech.config.SymmetricCipherSuite) - if err != nil { - return nil, nil, nil, err - } - ech.kdfID, ech.aeadID = suite.KDFID, suite.AEADID info := append([]byte("tls ech\x00"), ech.config.raw...) - ech.encapsulatedKey, ech.hpkeContext, err = hpke.SetupSender(ech.config.KemID, suite.KDFID, suite.AEADID, echPK, info) + ech.encapsulatedKey, ech.hpkeContext, err = hpke.NewSender(echPK, kdf, aead, info) if err != nil { return nil, nil, nil, err } @@ -324,7 +286,11 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) { if hello.earlyData { suite := cipherSuiteTLS13ByID(session.cipherSuite) transcript := suite.hash.New() - if err := transcriptMsg(hello, transcript); err != nil { + transcriptHello := hello + if ech != nil { + transcriptHello = ech.innerHello + } + if err := transcriptMsg(transcriptHello, transcript); err != nil { return err } earlyTrafficSecret := earlySecret.ClientEarlyTrafficSecret(transcript) @@ -433,9 +399,6 @@ func (c *Conn) loadSession(hello *clientHelloMsg) ( return nil, nil, nil, nil } - // Check that the cached server certificate is not expired, and that it's - // valid for the ServerName. This should be ensured by the cache key, but - // protect the application from a faulty ClientSessionCache implementation. if c.config.time().After(session.peerCertificates[0].NotAfter) { // Expired certificate, delete the entry. c.config.ClientSessionCache.Put(cacheKey, nil) @@ -447,6 +410,18 @@ func (c *Conn) loadSession(hello *clientHelloMsg) ( return nil, nil, nil, nil } if err := session.peerCertificates[0].VerifyHostname(c.config.ServerName); err != nil { + // This should be ensured by the cache key, but protect the + // application from a faulty ClientSessionCache implementation. + return nil, nil, nil, nil + } + opts := x509.VerifyOptions{ + CurrentTime: c.config.time(), + Roots: c.config.RootCAs, + KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + } + if !anyValidVerifiedChain(session.verifiedChains, opts) { + // No valid chains, delete the entry. + c.config.ClientSessionCache.Put(cacheKey, nil) return nil, nil, nil, nil } } @@ -725,8 +700,9 @@ func (hs *clientHandshakeState) doFullHandshake() error { c.sendAlert(alertIllegalParameter) return err } - if len(skx.key) >= 3 && skx.key[0] == 3 /* named curve */ { - c.curveID = CurveID(binary.BigEndian.Uint16(skx.key[1:])) + if keyAgreement, ok := keyAgreement.(*ecdheKeyAgreement); ok { + c.curveID = keyAgreement.curveID + c.peerSigAlg = keyAgreement.signatureAlgorithm } msg, err = c.readHandshake(&hs.finishedHash) @@ -807,37 +783,43 @@ func (hs *clientHandshakeState) doFullHandshake() error { return fmt.Errorf("tls: client certificate private key of type %T does not implement crypto.Signer", chainToSend.PrivateKey) } - var sigType uint8 - var sigHash crypto.Hash if c.vers >= VersionTLS12 { signatureAlgorithm, err := selectSignatureScheme(c.vers, chainToSend, certReq.supportedSignatureAlgorithms) if err != nil { - c.sendAlert(alertIllegalParameter) + c.sendAlert(alertHandshakeFailure) return err } - sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm) + sigType, sigHash, err := typeAndHashFromSignatureScheme(signatureAlgorithm) if err != nil { return c.sendAlert(alertInternalError) } certVerify.hasSignatureAlgorithm = true certVerify.signatureAlgorithm = signatureAlgorithm + if hs.finishedHash.buffer == nil { + c.sendAlert(alertInternalError) + return errors.New("tls: internal error: did not keep handshake transcript for TLS 1.2") + } + signOpts := crypto.SignerOpts(sigHash) + if sigType == signatureRSAPSS { + signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} + } + certVerify.signature, err = crypto.SignMessage(key, c.config.rand(), hs.finishedHash.buffer, signOpts) + if err != nil { + c.sendAlert(alertInternalError) + return err + } } else { - sigType, sigHash, err = legacyTypeAndHashFromPublicKey(key.Public()) + sigType, sigHash, err := legacyTypeAndHashFromPublicKey(key.Public()) if err != nil { c.sendAlert(alertIllegalParameter) return err } - } - - signed := hs.finishedHash.hashForClientCertificate(sigType, sigHash) - signOpts := crypto.SignerOpts(sigHash) - if sigType == signatureRSAPSS { - signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} - } - certVerify.signature, err = key.Sign(c.config.rand(), signed, signOpts) - if err != nil { - c.sendAlert(alertInternalError) - return err + signed := hs.finishedHash.hashForClientCertificate(sigType) + certVerify.signature, err = key.Sign(c.config.rand(), signed, sigHash) + if err != nil { + c.sendAlert(alertInternalError) + return err + } } if _, err := hs.c.writeHandshakeRecord(certVerify, &hs.finishedHash); err != nil { diff --git a/handshake_client_tls13.go b/handshake_client_tls13.go index 186c6e7..b30f992 100644 --- a/handshake_client_tls13.go +++ b/handshake_client_tls13.go @@ -10,7 +10,6 @@ import ( "crypto" "crypto/hkdf" "crypto/hmac" - "crypto/mlkem" "crypto/rsa" "crypto/subtle" "errors" @@ -320,22 +319,18 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error { c.sendAlert(alertIllegalParameter) return errors.New("tls: server sent an unnecessary HelloRetryRequest key_share") } - // Note: we don't support selecting X25519MLKEM768 in a HRR, because it - // is currently first in preference order, so if it's enabled we'll - // always send a key share for it. - // - // This will have to change once we support multiple hybrid KEMs. - if _, ok := curveForCurveID(curveID); !ok { + ke, err := keyExchangeForCurveID(curveID) + if err != nil { c.sendAlert(alertInternalError) return errors.New("tls: CurvePreferences includes unsupported curve") } - key, err := generateECDHEKey(c.config.rand(), curveID) + hs.keyShareKeys, hello.keyShares, err = ke.keyShares(c.config.rand()) if err != nil { c.sendAlert(alertInternalError) return err } - hs.keyShareKeys = &keySharePrivateKeys{curveID: curveID, ecdhe: key} - hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}} + // Do not send the fallback ECDH key share in a HRR response. + hello.keyShares = hello.keyShares[:1] } if len(hello.pskIdentities) > 0 { @@ -475,36 +470,16 @@ func (hs *clientHandshakeStateTLS13) processServerHello() error { func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error { c := hs.c - ecdhePeerData := hs.serverHello.serverShare.data - if hs.serverHello.serverShare.group == X25519MLKEM768 { - if len(ecdhePeerData) != mlkem.CiphertextSize768+x25519PublicKeySize { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: invalid server X25519MLKEM768 key share") - } - ecdhePeerData = hs.serverHello.serverShare.data[mlkem.CiphertextSize768:] - } - peerKey, err := hs.keyShareKeys.ecdhe.Curve().NewPublicKey(ecdhePeerData) + ke, err := keyExchangeForCurveID(hs.serverHello.serverShare.group) if err != nil { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: invalid server key share") + c.sendAlert(alertInternalError) + return err } - sharedKey, err := hs.keyShareKeys.ecdhe.ECDH(peerKey) + sharedKey, err := ke.clientSharedSecret(hs.keyShareKeys, hs.serverHello.serverShare.data) if err != nil { c.sendAlert(alertIllegalParameter) return errors.New("tls: invalid server key share") } - if hs.serverHello.serverShare.group == X25519MLKEM768 { - if hs.keyShareKeys.mlkem == nil { - return c.sendAlert(alertInternalError) - } - ciphertext := hs.serverHello.serverShare.data[:mlkem.CiphertextSize768] - mlkemShared, err := hs.keyShareKeys.mlkem.Decapsulate(ciphertext) - if err != nil { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: invalid X25519MLKEM768 server key share") - } - sharedKey = append(mlkemShared, sharedKey...) - } c.curveID = hs.serverHello.serverShare.group earlySecret := hs.earlySecret @@ -515,16 +490,17 @@ func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error { handshakeSecret := earlySecret.HandshakeSecret(sharedKey) clientSecret := handshakeSecret.ClientHandshakeTrafficSecret(hs.transcript) - c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, clientSecret) + c.setWriteTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, clientSecret) serverSecret := handshakeSecret.ServerHandshakeTrafficSecret(hs.transcript) - c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, serverSecret) + if err := c.setReadTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, serverSecret, false); err != nil { + return err + } if c.quic != nil { - if c.hand.Len() != 0 { - c.sendAlert(alertUnexpectedMessage) - } c.quicSetWriteSecret(QUICEncryptionLevelHandshake, hs.suite.id, clientSecret) - c.quicSetReadSecret(QUICEncryptionLevelHandshake, hs.suite.id, serverSecret) + if err := c.quicSetReadSecret(QUICEncryptionLevelHandshake, hs.suite.id, serverSecret); err != nil { + return err + } } err = c.config.writeKeyLog(keyLogLabelClientHandshake, hs.hello.random, clientSecret) @@ -677,7 +653,8 @@ func (hs *clientHandshakeStateTLS13) readServerCertificate() error { // See RFC 8446, Section 4.4.3. // We don't use hs.hello.supportedSignatureAlgorithms because it might // include PKCS#1 v1.5 and SHA-1 if the ClientHello also supported TLS 1.2. - if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms(c.vers)) { + if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms(c.vers)) || + !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, signatureSchemesForPublicKey(c.vers, c.peerCertificates[0].PublicKey)) { c.sendAlert(alertIllegalParameter) return errors.New("tls: certificate used with invalid signature algorithm") } @@ -688,12 +665,13 @@ func (hs *clientHandshakeStateTLS13) readServerCertificate() error { if sigType == signaturePKCS1v15 || sigHash == crypto.SHA1 { return c.sendAlert(alertInternalError) } - signed := signedMessage(sigHash, serverSignatureContext, hs.transcript) + signed := signedMessage(serverSignatureContext, hs.transcript) if err := verifyHandshakeSignature(sigType, c.peerCertificates[0].PublicKey, sigHash, signed, certVerify.signature); err != nil { c.sendAlert(alertDecryptError) return errors.New("tls: invalid signature by the server certificate: " + err.Error()) } + c.peerSigAlg = certVerify.signatureAlgorithm if err := transcriptMsg(certVerify, hs.transcript); err != nil { return err @@ -733,7 +711,9 @@ func (hs *clientHandshakeStateTLS13) readServerFinished() error { hs.trafficSecret = hs.masterSecret.ClientApplicationTrafficSecret(hs.transcript) serverSecret := hs.masterSecret.ServerApplicationTrafficSecret(hs.transcript) - c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, serverSecret) + if err := c.setReadTrafficSecret(hs.suite, QUICEncryptionLevelApplication, serverSecret, false); err != nil { + return err + } err = c.config.writeKeyLog(keyLogLabelClientTraffic, hs.hello.random, hs.trafficSecret) if err != nil { @@ -806,12 +786,12 @@ func (hs *clientHandshakeStateTLS13) sendClientCertificate() error { return c.sendAlert(alertInternalError) } - signed := signedMessage(sigHash, clientSignatureContext, hs.transcript) + signed := signedMessage(clientSignatureContext, hs.transcript) signOpts := crypto.SignerOpts(sigHash) if sigType == signatureRSAPSS { signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} } - sig, err := cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), signed, signOpts) + sig, err := crypto.SignMessage(cert.PrivateKey.(crypto.Signer), c.config.rand(), signed, signOpts) if err != nil { c.sendAlert(alertInternalError) return errors.New("tls: failed to sign handshake: " + err.Error()) @@ -836,16 +816,13 @@ func (hs *clientHandshakeStateTLS13) sendClientFinished() error { return err } - c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, hs.trafficSecret) + c.setWriteTrafficSecret(hs.suite, QUICEncryptionLevelApplication, hs.trafficSecret) if !c.config.SessionTicketsDisabled && c.config.ClientSessionCache != nil { c.resumptionSecret = hs.masterSecret.ResumptionMasterSecret(hs.transcript) } if c.quic != nil { - if c.hand.Len() != 0 { - c.sendAlert(alertUnexpectedMessage) - } c.quicSetWriteSecret(QUICEncryptionLevelApplication, hs.suite.id, hs.trafficSecret) } diff --git a/handshake_messages.go b/handshake_messages.go index 52fcf02..3653035 100644 --- a/handshake_messages.go +++ b/handshake_messages.go @@ -1005,6 +1005,7 @@ type encryptedExtensionsMsg struct { quicTransportParameters []byte earlyData bool echRetryConfigs []byte + serverNameAck bool } func (m *encryptedExtensionsMsg) marshal() ([]byte, error) { @@ -1040,6 +1041,10 @@ func (m *encryptedExtensionsMsg) marshal() ([]byte, error) { b.AddBytes(m.echRetryConfigs) }) } + if m.serverNameAck { + b.AddUint16(extensionServerName) + b.AddUint16(0) // empty extension_data + } }) }) @@ -1095,6 +1100,11 @@ func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool { if !extData.CopyBytes(m.echRetryConfigs) { return false } + case extensionServerName: + if len(extData) != 0 { + return false + } + m.serverNameAck = true default: // Ignore unknown extensions. continue diff --git a/handshake_server.go b/handshake_server.go index 5bf57bc..0079fd5 100644 --- a/handshake_server.go +++ b/handshake_server.go @@ -12,7 +12,6 @@ import ( "crypto/rsa" "crypto/subtle" "crypto/x509" - "encoding/binary" "errors" "fmt" "hash" @@ -354,7 +353,7 @@ func negotiateALPN(serverProtos, clientProtos []string, quic bool) (string, erro if http11fallback { return "", nil } - return "", fmt.Errorf("tls: client requested unsupported application protocols (%s)", clientProtos) + return "", fmt.Errorf("tls: client requested unsupported application protocols (%q)", clientProtos) } // supportsECDHE returns whether ECDHE key exchanges can be used with this @@ -511,8 +510,13 @@ func (hs *serverHandshakeState) checkForResumption() error { if sessionHasClientCerts && c.config.time().After(sessionState.peerCertificates[0].NotAfter) { return nil } + opts := x509.VerifyOptions{ + CurrentTime: c.config.time(), + Roots: c.config.ClientCAs, + KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + } if sessionHasClientCerts && c.config.ClientAuth >= VerifyClientCertIfGiven && - len(sessionState.verifiedChains) == 0 { + !anyValidVerifiedChain(sessionState.verifiedChains, opts) { return nil } @@ -582,6 +586,10 @@ func (hs *serverHandshakeState) doFullHandshake() error { hs.hello.ocspStapling = true } + if hs.clientHello.serverName != "" { + hs.hello.serverNameAck = true + } + hs.hello.ticketSupported = hs.clientHello.ticketSupported && !c.config.SessionTicketsDisabled hs.hello.cipherSuite = hs.suite.id @@ -619,8 +627,9 @@ func (hs *serverHandshakeState) doFullHandshake() error { return err } if skx != nil { - if len(skx.key) >= 3 && skx.key[0] == 3 /* named curve */ { - c.curveID = CurveID(binary.BigEndian.Uint16(skx.key[1:])) + if keyAgreement, ok := keyAgreement.(*ecdheKeyAgreement); ok { + c.curveID = keyAgreement.curveID + c.peerSigAlg = keyAgreement.signatureAlgorithm } if _, err := hs.c.writeHandshakeRecord(skx, &hs.finishedHash); err != nil { return err @@ -759,19 +768,28 @@ func (hs *serverHandshakeState) doFullHandshake() error { if err != nil { return c.sendAlert(alertInternalError) } + if hs.finishedHash.buffer == nil { + c.sendAlert(alertInternalError) + return errors.New("tls: internal error: did not keep handshake transcript for TLS 1.2") + } + if err := verifyHandshakeSignature(sigType, pub, sigHash, hs.finishedHash.buffer, certVerify.signature); err != nil { + c.sendAlert(alertDecryptError) + return errors.New("tls: invalid signature by the client certificate: " + err.Error()) + } } else { sigType, sigHash, err = legacyTypeAndHashFromPublicKey(pub) if err != nil { c.sendAlert(alertIllegalParameter) return err } + signed := hs.finishedHash.hashForClientCertificate(sigType) + if err := verifyLegacyHandshakeSignature(sigType, pub, sigHash, signed, certVerify.signature); err != nil { + c.sendAlert(alertDecryptError) + return errors.New("tls: invalid signature by the client certificate: " + err.Error()) + } } - signed := hs.finishedHash.hashForClientCertificate(sigType, sigHash) - if err := verifyHandshakeSignature(sigType, pub, sigHash, signed, certVerify.signature); err != nil { - c.sendAlert(alertDecryptError) - return errors.New("tls: invalid signature by the client certificate: " + err.Error()) - } + c.peerSigAlg = certVerify.signatureAlgorithm if err := transcriptMsg(certVerify, &hs.finishedHash); err != nil { return err @@ -943,10 +961,9 @@ func (c *Conn) processCertsFromClient(certificate Certificate) error { chains, err := certs[0].Verify(opts) if err != nil { - var errCertificateInvalid x509.CertificateInvalidError - if errors.As(err, &x509.UnknownAuthorityError{}) { + if _, ok := errors.AsType[x509.UnknownAuthorityError](err); ok { c.sendAlert(alertUnknownCA) - } else if errors.As(err, &errCertificateInvalid) && errCertificateInvalid.Reason == x509.Expired { + } else if errCertificateInvalid, ok := errors.AsType[x509.CertificateInvalidError](err); ok && errCertificateInvalid.Reason == x509.Expired { c.sendAlert(alertCertificateExpired) } else { c.sendAlert(alertBadCertificate) @@ -990,6 +1007,10 @@ func clientHelloInfo(ctx context.Context, c *Conn, clientHello *clientHelloMsg) supportedVersions = supportedVersionsFromMax(clientHello.vers) } + conn := c.conn + if c.quic != nil { + conn = c.quic.clientHelloInfoConn + } return &ClientHelloInfo{ CipherSuites: clientHello.cipherSuites, ServerName: clientHello.serverName, @@ -999,7 +1020,8 @@ func clientHelloInfo(ctx context.Context, c *Conn, clientHello *clientHelloMsg) SupportedProtos: clientHello.alpnProtocols, SupportedVersions: supportedVersions, Extensions: clientHello.extensions, - Conn: c.conn, + Conn: conn, + HelloRetryRequest: c.didHRR, config: c.config, ctx: ctx, } diff --git a/handshake_server_tls13.go b/handshake_server_tls13.go index 5595662..0e1a0c5 100644 --- a/handshake_server_tls13.go +++ b/handshake_server_tls13.go @@ -11,6 +11,7 @@ import ( "crypto/ed25519" "crypto/hkdf" "crypto/hmac" + "crypto/hpke" "crypto/mlkem" "crypto/rand" "crypto/rsa" @@ -29,7 +30,6 @@ import ( "github.com/cloudflare/circl/sign/mldsa/mldsa65" "github.com/xtls/reality/fips140tls" - "github.com/xtls/reality/hpke" "github.com/xtls/reality/tls13" ) @@ -43,7 +43,7 @@ type echServerContext struct { configID uint8 ciphersuite EchCipher transcript hash.Hash - // inner indicates that the initial client_hello we recieved contained an + // inner indicates that the initial client_hello we received contained an // encrypted_client_hello extension that indicated it was an "inner" hello. // We don't do any additional processing of the hello in this case, so all // fields above are unset. @@ -340,55 +340,16 @@ func (hs *serverHandshakeStateTLS13) processClientHello() error { } c.curveID = selectedGroup - ecdhGroup := selectedGroup - ecdhData := clientKeyShare.data - if selectedGroup == X25519MLKEM768 { - ecdhGroup = X25519 - if len(ecdhData) != mlkem.EncapsulationKeySize768+x25519PublicKeySize { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: invalid X25519MLKEM768 client key share") - } - ecdhData = ecdhData[mlkem.EncapsulationKeySize768:] - } - if _, ok := curveForCurveID(ecdhGroup); !ok { - c.sendAlert(alertInternalError) - return errors.New("tls: CurvePreferences includes unsupported curve") - } - key, err := generateECDHEKey(c.config.rand(), ecdhGroup) + ke, err := keyExchangeForCurveID(selectedGroup) if err != nil { c.sendAlert(alertInternalError) - return err - } - hs.hello.serverShare = keyShare{group: selectedGroup, data: key.PublicKey().Bytes()} - peerKey, err := key.Curve().NewPublicKey(ecdhData) - if err != nil { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: invalid client key share") + return errors.New("tls: CurvePreferences includes unsupported curve") } - hs.sharedKey, err = key.ECDH(peerKey) + hs.sharedKey, hs.hello.serverShare, err = ke.serverSharedSecret(c.config.rand(), clientKeyShare.data) if err != nil { c.sendAlert(alertIllegalParameter) return errors.New("tls: invalid client key share") } - if selectedGroup == X25519MLKEM768 { - k, err := mlkem.NewEncapsulationKey768(clientKeyShare.data[:mlkem.EncapsulationKeySize768]) - if err != nil { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: invalid X25519MLKEM768 client key share") - } - mlkemSharedSecret, ciphertext := k.Encapsulate() - // draft-kwiatkowski-tls-ecdhe-mlkem-02, Section 3.1.3: "For - // X25519MLKEM768, the shared secret is the concatenation of the ML-KEM - // shared secret and the X25519 shared secret. The shared secret is 64 - // bytes (32 bytes for each part)." - hs.sharedKey = append(mlkemSharedSecret, hs.sharedKey...) - // draft-kwiatkowski-tls-ecdhe-mlkem-02, Section 3.1.2: "When the - // X25519MLKEM768 group is negotiated, the server's key exchange value - // is the concatenation of an ML-KEM ciphertext returned from - // encapsulation to the client's encapsulation key, and the server's - // ephemeral X25519 share." - hs.hello.serverShare.data = append(ciphertext, hs.hello.serverShare.data...) - } selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols, c.quic != nil) if err != nil { @@ -503,8 +464,13 @@ func (hs *serverHandshakeStateTLS13) checkForResumption() error { if sessionHasClientCerts && c.config.time().After(sessionState.peerCertificates[0].NotAfter) { continue } + opts := x509.VerifyOptions{ + CurrentTime: c.config.time(), + Roots: c.config.ClientCAs, + KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + } if sessionHasClientCerts && c.config.ClientAuth >= VerifyClientCertIfGiven && - len(sessionState.verifiedChains) == 0 { + !anyValidVerifiedChain(sessionState.verifiedChains, opts) { continue } @@ -544,7 +510,9 @@ func (hs *serverHandshakeStateTLS13) checkForResumption() error { return err } earlyTrafficSecret := hs.earlySecret.ClientEarlyTrafficSecret(transcript) - c.quicSetReadSecret(QUICEncryptionLevelEarly, hs.suite.id, earlyTrafficSecret) + if err := c.quicSetReadSecret(QUICEncryptionLevelEarly, hs.suite.id, earlyTrafficSecret); err != nil { + return err + } } c.didResume = true @@ -562,10 +530,17 @@ func (hs *serverHandshakeStateTLS13) checkForResumption() error { return nil } -// cloneHash uses the encoding.BinaryMarshaler and encoding.BinaryUnmarshaler +// cloneHash uses [hash.Cloner] to clone in. If [hash.Cloner] +// is not implemented or not supported, then it falls back to the +// [encoding.BinaryMarshaler] and [encoding.BinaryUnmarshaler] // interfaces implemented by standard library hashes to clone the state of in // to a new instance of h. It returns nil if the operation fails. func cloneHash(in hash.Hash, h crypto.Hash) hash.Hash { + if cloner, ok := in.(hash.Cloner); ok { + if out, err := cloner.Clone(); err == nil { + return out + } + } // Recreate the interface to avoid importing encoding. type binaryMarshaler interface { MarshalBinary() (data []byte, err error) @@ -641,6 +616,14 @@ func (hs *serverHandshakeStateTLS13) sendDummyChangeCipherSpec() error { func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID) (*keyShare, error) { c := hs.c + // Make sure the client didn't send extra handshake messages alongside + // their initial client_hello. If they sent two client_hello messages, + // we will consume the second before they respond to the server_hello. + if c.hand.Len() != 0 { + c.sendAlert(alertUnexpectedMessage) + return nil, errors.New("tls: handshake buffer not empty before HelloRetryRequest") + } + // The first ClientHello gets double-hashed into the transcript upon a // HelloRetryRequest. See RFC 8446, Section 4.4.1. if err := transcriptMsg(hs.clientHello, hs.transcript); err != nil { @@ -868,17 +851,18 @@ func (hs *serverHandshakeStateTLS13) sendServerParameters() error { } hs.handshakeSecret = earlySecret.HandshakeSecret(hs.sharedKey) - clientSecret := hs.handshakeSecret.ClientHandshakeTrafficSecret(hs.transcript) - c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, clientSecret) serverSecret := hs.handshakeSecret.ServerHandshakeTrafficSecret(hs.transcript) - c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, serverSecret) + c.setWriteTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, serverSecret) + clientSecret := hs.handshakeSecret.ClientHandshakeTrafficSecret(hs.transcript) + if err := c.setReadTrafficSecret(hs.suite, QUICEncryptionLevelHandshake, clientSecret, false); err != nil { + return err + } if c.quic != nil { - if c.hand.Len() != 0 { - c.sendAlert(alertUnexpectedMessage) - } c.quicSetWriteSecret(QUICEncryptionLevelHandshake, hs.suite.id, serverSecret) - c.quicSetReadSecret(QUICEncryptionLevelHandshake, hs.suite.id, clientSecret) + if err := c.quicSetReadSecret(QUICEncryptionLevelHandshake, hs.suite.id, clientSecret); err != nil { + return err + } } err := c.config.writeKeyLog(keyLogLabelClientHandshake, hs.clientHello.random, clientSecret) @@ -904,6 +888,10 @@ func (hs *serverHandshakeStateTLS13) sendServerParameters() error { encryptedExtensions.earlyData = hs.earlyData } + if !hs.c.didResume && hs.clientHello.serverName != "" { + encryptedExtensions.serverNameAck = true + } + // If client sent ECH extension, but we didn't accept it, // send retry configs, if available. echKeys := hs.c.config.EncryptedClientHelloKeys @@ -976,12 +964,12 @@ func (hs *serverHandshakeStateTLS13) sendServerCertificate() error { return c.sendAlert(alertInternalError) } - signed := signedMessage(sigHash, serverSignatureContext, hs.transcript) + signed := signedMessage(serverSignatureContext, hs.transcript) signOpts := crypto.SignerOpts(sigHash) if sigType == signatureRSAPSS { signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} } - sig, err := hs.cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), signed, signOpts) + sig, err := crypto.SignMessage(hs.cert.PrivateKey.(crypto.Signer), c.config.rand(), signed, signOpts) if err != nil { public := hs.cert.PrivateKey.(crypto.Signer).Public() if rsaKey, ok := public.(*rsa.PublicKey); ok && sigType == signatureRSAPSS && @@ -1018,13 +1006,9 @@ func (hs *serverHandshakeStateTLS13) sendServerFinished() error { hs.trafficSecret = hs.masterSecret.ClientApplicationTrafficSecret(hs.transcript) serverSecret := hs.masterSecret.ServerApplicationTrafficSecret(hs.transcript) - c.out.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, serverSecret) + c.setWriteTrafficSecret(hs.suite, QUICEncryptionLevelApplication, serverSecret) if c.quic != nil { - if c.hand.Len() != 0 { - // TODO: Handle this in setTrafficSecret? - c.sendAlert(alertUnexpectedMessage) - } c.quicSetWriteSecret(QUICEncryptionLevelApplication, hs.suite.id, serverSecret) } @@ -1200,7 +1184,8 @@ func (hs *serverHandshakeStateTLS13) readClientCertificate() error { // See RFC 8446, Section 4.4.3. // We don't use certReq.supportedSignatureAlgorithms because it would // require keeping the certificateRequestMsgTLS13 around in the hs. - if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms(c.vers)) { + if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms(c.vers)) || + !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, signatureSchemesForPublicKey(c.vers, c.peerCertificates[0].PublicKey)) { c.sendAlert(alertIllegalParameter) return errors.New("tls: client certificate used with invalid signature algorithm") } @@ -1211,12 +1196,13 @@ func (hs *serverHandshakeStateTLS13) readClientCertificate() error { if sigType == signaturePKCS1v15 || sigHash == crypto.SHA1 { return c.sendAlert(alertInternalError) } - signed := signedMessage(sigHash, clientSignatureContext, hs.transcript) + signed := signedMessage(clientSignatureContext, hs.transcript) if err := verifyHandshakeSignature(sigType, c.peerCertificates[0].PublicKey, sigHash, signed, certVerify.signature); err != nil { c.sendAlert(alertDecryptError) return errors.New("tls: invalid signature by the client certificate: " + err.Error()) } + c.peerSigAlg = certVerify.signatureAlgorithm if err := transcriptMsg(certVerify, hs.transcript); err != nil { return err @@ -1252,7 +1238,9 @@ func (hs *serverHandshakeStateTLS13) readClientFinished() error { return errors.New("tls: invalid client finished hash") } - c.in.setTrafficSecret(hs.suite, QUICEncryptionLevelApplication, hs.trafficSecret) + if err := c.setReadTrafficSecret(hs.suite, QUICEncryptionLevelApplication, hs.trafficSecret, false); err != nil { + return err + } return nil } diff --git a/hpke/hpye.go b/hpke/hpye.go deleted file mode 100644 index 9ef26be..0000000 --- a/hpke/hpye.go +++ /dev/null @@ -1,355 +0,0 @@ -// Copyright 2024 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package hpke - -import ( - "crypto" - "crypto/aes" - "crypto/cipher" - "crypto/ecdh" - "crypto/hkdf" - "crypto/rand" - "encoding/binary" - "errors" - "math/bits" - - "golang.org/x/crypto/chacha20poly1305" -) - -// testingOnlyGenerateKey is only used during testing, to provide -// a fixed test key to use when checking the RFC 9180 vectors. -var testingOnlyGenerateKey func() (*ecdh.PrivateKey, error) - -type hkdfKDF struct { - hash crypto.Hash -} - -func (kdf *hkdfKDF) LabeledExtract(sid []byte, salt []byte, label string, inputKey []byte) ([]byte, error) { - labeledIKM := make([]byte, 0, 7+len(sid)+len(label)+len(inputKey)) - labeledIKM = append(labeledIKM, []byte("HPKE-v1")...) - labeledIKM = append(labeledIKM, sid...) - labeledIKM = append(labeledIKM, label...) - labeledIKM = append(labeledIKM, inputKey...) - return hkdf.Extract(kdf.hash.New, labeledIKM, salt) -} - -func (kdf *hkdfKDF) LabeledExpand(suiteID []byte, randomKey []byte, label string, info []byte, length uint16) ([]byte, error) { - labeledInfo := make([]byte, 0, 2+7+len(suiteID)+len(label)+len(info)) - labeledInfo = binary.BigEndian.AppendUint16(labeledInfo, length) - labeledInfo = append(labeledInfo, []byte("HPKE-v1")...) - labeledInfo = append(labeledInfo, suiteID...) - labeledInfo = append(labeledInfo, label...) - labeledInfo = append(labeledInfo, info...) - return hkdf.Expand(kdf.hash.New, randomKey, string(labeledInfo), int(length)) -} - -// dhKEM implements the KEM specified in RFC 9180, Section 4.1. -type dhKEM struct { - dh ecdh.Curve - kdf hkdfKDF - - suiteID []byte - nSecret uint16 -} - -type KemID uint16 - -const DHKEM_X25519_HKDF_SHA256 = 0x0020 - -var SupportedKEMs = map[uint16]struct { - curve ecdh.Curve - hash crypto.Hash - nSecret uint16 -}{ - // RFC 9180 Section 7.1 - DHKEM_X25519_HKDF_SHA256: {ecdh.X25519(), crypto.SHA256, 32}, -} - -func newDHKem(kemID uint16) (*dhKEM, error) { - suite, ok := SupportedKEMs[kemID] - if !ok { - return nil, errors.New("unsupported suite ID") - } - return &dhKEM{ - dh: suite.curve, - kdf: hkdfKDF{suite.hash}, - suiteID: binary.BigEndian.AppendUint16([]byte("KEM"), kemID), - nSecret: suite.nSecret, - }, nil -} - -func (dh *dhKEM) ExtractAndExpand(dhKey, kemContext []byte) ([]byte, error) { - eaePRK, err := dh.kdf.LabeledExtract(dh.suiteID[:], nil, "eae_prk", dhKey) - if err != nil { - return nil, err - } - return dh.kdf.LabeledExpand(dh.suiteID[:], eaePRK, "shared_secret", kemContext, dh.nSecret) -} - -func (dh *dhKEM) Encap(pubRecipient *ecdh.PublicKey) (sharedSecret []byte, encapPub []byte, err error) { - var privEph *ecdh.PrivateKey - if testingOnlyGenerateKey != nil { - privEph, err = testingOnlyGenerateKey() - } else { - privEph, err = dh.dh.GenerateKey(rand.Reader) - } - if err != nil { - return nil, nil, err - } - dhVal, err := privEph.ECDH(pubRecipient) - if err != nil { - return nil, nil, err - } - encPubEph := privEph.PublicKey().Bytes() - - encPubRecip := pubRecipient.Bytes() - kemContext := append(encPubEph, encPubRecip...) - sharedSecret, err = dh.ExtractAndExpand(dhVal, kemContext) - if err != nil { - return nil, nil, err - } - return sharedSecret, encPubEph, nil -} - -func (dh *dhKEM) Decap(encPubEph []byte, secRecipient *ecdh.PrivateKey) ([]byte, error) { - pubEph, err := dh.dh.NewPublicKey(encPubEph) - if err != nil { - return nil, err - } - dhVal, err := secRecipient.ECDH(pubEph) - if err != nil { - return nil, err - } - kemContext := append(encPubEph, secRecipient.PublicKey().Bytes()...) - return dh.ExtractAndExpand(dhVal, kemContext) -} - -type context struct { - aead cipher.AEAD - - sharedSecret []byte - - suiteID []byte - - key []byte - baseNonce []byte - exporterSecret []byte - - seqNum uint128 -} - -type Sender struct { - *context -} - -type Recipient struct { - *context -} - -var aesGCMNew = func(key []byte) (cipher.AEAD, error) { - block, err := aes.NewCipher(key) - if err != nil { - return nil, err - } - return cipher.NewGCM(block) -} - -type AEADID uint16 - -const ( - AEAD_AES_128_GCM = 0x0001 - AEAD_AES_256_GCM = 0x0002 - AEAD_ChaCha20Poly1305 = 0x0003 -) - -var SupportedAEADs = map[uint16]struct { - keySize int - nonceSize int - aead func([]byte) (cipher.AEAD, error) -}{ - // RFC 9180, Section 7.3 - AEAD_AES_128_GCM: {keySize: 16, nonceSize: 12, aead: aesGCMNew}, - AEAD_AES_256_GCM: {keySize: 32, nonceSize: 12, aead: aesGCMNew}, - AEAD_ChaCha20Poly1305: {keySize: chacha20poly1305.KeySize, nonceSize: chacha20poly1305.NonceSize, aead: chacha20poly1305.New}, -} - -type KDFID uint16 - -const KDF_HKDF_SHA256 = 0x0001 - -var SupportedKDFs = map[uint16]func() *hkdfKDF{ - // RFC 9180, Section 7.2 - KDF_HKDF_SHA256: func() *hkdfKDF { return &hkdfKDF{crypto.SHA256} }, -} - -func newContext(sharedSecret []byte, kemID, kdfID, aeadID uint16, info []byte) (*context, error) { - sid := suiteID(kemID, kdfID, aeadID) - - kdfInit, ok := SupportedKDFs[kdfID] - if !ok { - return nil, errors.New("unsupported KDF id") - } - kdf := kdfInit() - - aeadInfo, ok := SupportedAEADs[aeadID] - if !ok { - return nil, errors.New("unsupported AEAD id") - } - - pskIDHash, err := kdf.LabeledExtract(sid, nil, "psk_id_hash", nil) - if err != nil { - return nil, err - } - infoHash, err := kdf.LabeledExtract(sid, nil, "info_hash", info) - if err != nil { - return nil, err - } - ksContext := append([]byte{0}, pskIDHash...) - ksContext = append(ksContext, infoHash...) - - secret, err := kdf.LabeledExtract(sid, sharedSecret, "secret", nil) - if err != nil { - return nil, err - } - key, err := kdf.LabeledExpand(sid, secret, "key", ksContext, uint16(aeadInfo.keySize) /* Nk - key size for AEAD */) - if err != nil { - return nil, err - } - baseNonce, err := kdf.LabeledExpand(sid, secret, "base_nonce", ksContext, uint16(aeadInfo.nonceSize) /* Nn - nonce size for AEAD */) - if err != nil { - return nil, err - } - exporterSecret, err := kdf.LabeledExpand(sid, secret, "exp", ksContext, uint16(kdf.hash.Size()) /* Nh - hash output size of the kdf*/) - if err != nil { - return nil, err - } - - aead, err := aeadInfo.aead(key) - if err != nil { - return nil, err - } - - return &context{ - aead: aead, - sharedSecret: sharedSecret, - suiteID: sid, - key: key, - baseNonce: baseNonce, - exporterSecret: exporterSecret, - }, nil -} - -func SetupSender(kemID, kdfID, aeadID uint16, pub *ecdh.PublicKey, info []byte) ([]byte, *Sender, error) { - kem, err := newDHKem(kemID) - if err != nil { - return nil, nil, err - } - sharedSecret, encapsulatedKey, err := kem.Encap(pub) - if err != nil { - return nil, nil, err - } - - context, err := newContext(sharedSecret, kemID, kdfID, aeadID, info) - if err != nil { - return nil, nil, err - } - - return encapsulatedKey, &Sender{context}, nil -} - -func SetupRecipient(kemID, kdfID, aeadID uint16, priv *ecdh.PrivateKey, info, encPubEph []byte) (*Recipient, error) { - kem, err := newDHKem(kemID) - if err != nil { - return nil, err - } - sharedSecret, err := kem.Decap(encPubEph, priv) - if err != nil { - return nil, err - } - - context, err := newContext(sharedSecret, kemID, kdfID, aeadID, info) - if err != nil { - return nil, err - } - - return &Recipient{context}, nil -} - -func (ctx *context) nextNonce() []byte { - nonce := ctx.seqNum.bytes()[16-ctx.aead.NonceSize():] - for i := range ctx.baseNonce { - nonce[i] ^= ctx.baseNonce[i] - } - return nonce -} - -func (ctx *context) incrementNonce() { - // Message limit is, according to the RFC, 2^95+1, which - // is somewhat confusing, but we do as we're told. - if ctx.seqNum.bitLen() >= (ctx.aead.NonceSize()*8)-1 { - panic("message limit reached") - } - ctx.seqNum = ctx.seqNum.addOne() -} - -func (s *Sender) Seal(aad, plaintext []byte) ([]byte, error) { - ciphertext := s.aead.Seal(nil, s.nextNonce(), plaintext, aad) - s.incrementNonce() - return ciphertext, nil -} - -func (r *Recipient) Open(aad, ciphertext []byte) ([]byte, error) { - plaintext, err := r.aead.Open(nil, r.nextNonce(), ciphertext, aad) - if err != nil { - return nil, err - } - r.incrementNonce() - return plaintext, nil -} - -func suiteID(kemID, kdfID, aeadID uint16) []byte { - suiteID := make([]byte, 0, 4+2+2+2) - suiteID = append(suiteID, []byte("HPKE")...) - suiteID = binary.BigEndian.AppendUint16(suiteID, kemID) - suiteID = binary.BigEndian.AppendUint16(suiteID, kdfID) - suiteID = binary.BigEndian.AppendUint16(suiteID, aeadID) - return suiteID -} - -func ParseHPKEPublicKey(kemID uint16, bytes []byte) (*ecdh.PublicKey, error) { - kemInfo, ok := SupportedKEMs[kemID] - if !ok { - return nil, errors.New("unsupported KEM id") - } - return kemInfo.curve.NewPublicKey(bytes) -} - -func ParseHPKEPrivateKey(kemID uint16, bytes []byte) (*ecdh.PrivateKey, error) { - kemInfo, ok := SupportedKEMs[kemID] - if !ok { - return nil, errors.New("unsupported KEM id") - } - return kemInfo.curve.NewPrivateKey(bytes) -} - -type uint128 struct { - hi, lo uint64 -} - -func (u uint128) addOne() uint128 { - lo, carry := bits.Add64(u.lo, 1, 0) - return uint128{u.hi + carry, lo} -} - -func (u uint128) bitLen() int { - return bits.Len64(u.hi) + bits.Len64(u.lo) -} - -func (u uint128) bytes() []byte { - b := make([]byte, 16) - binary.BigEndian.PutUint64(b[0:], u.hi) - binary.BigEndian.PutUint64(b[8:], u.lo) - return b -} \ No newline at end of file diff --git a/key_agreement.go b/key_agreement.go index 118f259..238597f 100644 --- a/key_agreement.go +++ b/key_agreement.go @@ -127,25 +127,8 @@ func md5SHA1Hash(slices [][]byte) []byte { } // hashForServerKeyExchange hashes the given slices and returns their digest -// using the given hash function (for TLS 1.2) or using a default based on -// the sigType (for earlier TLS versions). For Ed25519 signatures, which don't -// do pre-hashing, it returns the concatenation of the slices. -func hashForServerKeyExchange(sigType uint8, hashFunc crypto.Hash, version uint16, slices ...[]byte) []byte { - if sigType == signatureEd25519 { - var signed []byte - for _, slice := range slices { - signed = append(signed, slice...) - } - return signed - } - if version >= VersionTLS12 { - h := hashFunc.New() - for _, slice := range slices { - h.Write(slice) - } - digest := h.Sum(nil) - return digest - } +// using a hash based on the sigType. It can only be used for TLS 1.0 and 1.1. +func hashForServerKeyExchange(sigType uint8, slices ...[]byte) []byte { if sigType == signatureECDSA { return sha1Hash(slices) } @@ -159,31 +142,35 @@ func hashForServerKeyExchange(sigType uint8, hashFunc crypto.Hash, version uint1 type ecdheKeyAgreement struct { version uint16 isRSA bool - key *ecdh.PrivateKey // ckx and preMasterSecret are generated in processServerKeyExchange // and returned in generateClientKeyExchange. ckx *clientKeyExchangeMsg preMasterSecret []byte + + // curveID, signatureAlgorithm, and key are set by processServerKeyExchange + // and generateServerKeyExchange. + curveID CurveID + signatureAlgorithm SignatureScheme + key *ecdh.PrivateKey } func (ka *ecdheKeyAgreement) generateServerKeyExchange(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) { - var curveID CurveID for _, c := range clientHello.supportedCurves { if config.supportsCurve(ka.version, c) { - curveID = c + ka.curveID = c break } } - if curveID == 0 { + if ka.curveID == 0 { return nil, errors.New("tls: no supported elliptic curves offered") } - if _, ok := curveForCurveID(curveID); !ok { + if _, ok := curveForCurveID(ka.curveID); !ok { return nil, errors.New("tls: CurvePreferences includes unsupported curve") } - key, err := generateECDHEKey(config.rand(), curveID) + key, err := generateECDHEKey(config.rand(), ka.curveID) if err != nil { return nil, err } @@ -193,8 +180,8 @@ func (ka *ecdheKeyAgreement) generateServerKeyExchange(config *Config, cert *Cer ecdhePublic := key.PublicKey().Bytes() serverECDHEParams := make([]byte, 1+2+1+len(ecdhePublic)) serverECDHEParams[0] = 3 // named curve - serverECDHEParams[1] = byte(curveID >> 8) - serverECDHEParams[2] = byte(curveID) + serverECDHEParams[1] = byte(ka.curveID >> 8) + serverECDHEParams[2] = byte(ka.curveID) serverECDHEParams[3] = byte(len(ecdhePublic)) copy(serverECDHEParams[4:], ecdhePublic) @@ -203,37 +190,41 @@ func (ka *ecdheKeyAgreement) generateServerKeyExchange(config *Config, cert *Cer return nil, fmt.Errorf("tls: certificate private key of type %T does not implement crypto.Signer", cert.PrivateKey) } - var signatureAlgorithm SignatureScheme - var sigType uint8 - var sigHash crypto.Hash + var sig []byte if ka.version >= VersionTLS12 { - signatureAlgorithm, err = selectSignatureScheme(ka.version, cert, clientHello.supportedSignatureAlgorithms) + ka.signatureAlgorithm, err = selectSignatureScheme(ka.version, cert, clientHello.supportedSignatureAlgorithms) if err != nil { return nil, err } - sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm) + sigType, sigHash, err := typeAndHashFromSignatureScheme(ka.signatureAlgorithm) if err != nil { return nil, err } + signed := slices.Concat(clientHello.random, hello.random, serverECDHEParams) + if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA { + return nil, errors.New("tls: certificate cannot be used with the selected cipher suite") + } + signOpts := crypto.SignerOpts(sigHash) + if sigType == signatureRSAPSS { + signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} + } + sig, err = crypto.SignMessage(priv, config.rand(), signed, signOpts) + if err != nil { + return nil, errors.New("tls: failed to sign ECDHE parameters: " + err.Error()) + } } else { - sigType, sigHash, err = legacyTypeAndHashFromPublicKey(priv.Public()) + sigType, sigHash, err := legacyTypeAndHashFromPublicKey(priv.Public()) if err != nil { return nil, err } - } - if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA { - return nil, errors.New("tls: certificate cannot be used with the selected cipher suite") - } - - signed := hashForServerKeyExchange(sigType, sigHash, ka.version, clientHello.random, hello.random, serverECDHEParams) - - signOpts := crypto.SignerOpts(sigHash) - if sigType == signatureRSAPSS { - signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} - } - sig, err := priv.Sign(config.rand(), signed, signOpts) - if err != nil { - return nil, errors.New("tls: failed to sign ECDHE parameters: " + err.Error()) + signed := hashForServerKeyExchange(sigType, clientHello.random, hello.random, serverECDHEParams) + if (sigType == signaturePKCS1v15) != ka.isRSA { + return nil, errors.New("tls: certificate cannot be used with the selected cipher suite") + } + sig, err = priv.Sign(config.rand(), signed, sigHash) + if err != nil { + return nil, errors.New("tls: failed to sign ECDHE parameters: " + err.Error()) + } } skx := new(serverKeyExchangeMsg) @@ -245,8 +236,8 @@ func (ka *ecdheKeyAgreement) generateServerKeyExchange(config *Config, cert *Cer copy(skx.key, serverECDHEParams) k := skx.key[len(serverECDHEParams):] if ka.version >= VersionTLS12 { - k[0] = byte(signatureAlgorithm >> 8) - k[1] = byte(signatureAlgorithm) + k[0] = byte(ka.signatureAlgorithm >> 8) + k[1] = byte(ka.signatureAlgorithm) k = k[2:] } k[0] = byte(len(sig) >> 8) @@ -280,7 +271,7 @@ func (ka *ecdheKeyAgreement) processServerKeyExchange(config *Config, clientHell if skx.key[0] != 3 { // named curve return errors.New("tls: server selected unsupported curve") } - curveID := CurveID(skx.key[1])<<8 | CurveID(skx.key[2]) + ka.curveID = CurveID(skx.key[1])<<8 | CurveID(skx.key[2]) publicLen := int(skx.key[3]) if publicLen+4 > len(skx.key) { @@ -293,16 +284,28 @@ func (ka *ecdheKeyAgreement) processServerKeyExchange(config *Config, clientHell if len(sig) < 2 { return errServerKeyExchange } + if ka.version >= VersionTLS12 { + ka.signatureAlgorithm = SignatureScheme(sig[0])<<8 | SignatureScheme(sig[1]) + sig = sig[2:] + if len(sig) < 2 { + return errServerKeyExchange + } + } + sigLen := int(sig[0])<<8 | int(sig[1]) + if sigLen+2 != len(sig) { + return errServerKeyExchange + } + sig = sig[2:] - if !slices.Contains(clientHello.supportedCurves, curveID) { + if !slices.Contains(clientHello.supportedCurves, ka.curveID) { return errors.New("tls: server selected unoffered curve") } - if _, ok := curveForCurveID(curveID); !ok { + if _, ok := curveForCurveID(ka.curveID); !ok { return errors.New("tls: server selected unsupported curve") } - key, err := generateECDHEKey(config.rand(), curveID) + key, err := generateECDHEKey(config.rand(), ka.curveID) if err != nil { return err } @@ -326,38 +329,32 @@ func (ka *ecdheKeyAgreement) processServerKeyExchange(config *Config, clientHell var sigType uint8 var sigHash crypto.Hash if ka.version >= VersionTLS12 { - signatureAlgorithm := SignatureScheme(sig[0])<<8 | SignatureScheme(sig[1]) - sig = sig[2:] - if len(sig) < 2 { - return errServerKeyExchange - } - - if !isSupportedSignatureAlgorithm(signatureAlgorithm, clientHello.supportedSignatureAlgorithms) { + if !isSupportedSignatureAlgorithm(ka.signatureAlgorithm, clientHello.supportedSignatureAlgorithms) { return errors.New("tls: certificate used with invalid signature algorithm") } - sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm) + sigType, sigHash, err = typeAndHashFromSignatureScheme(ka.signatureAlgorithm) if err != nil { return err } + if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA { + return errServerKeyExchange + } + signed := slices.Concat(clientHello.random, serverHello.random, serverECDHEParams) + if err := verifyHandshakeSignature(sigType, cert.PublicKey, sigHash, signed, sig); err != nil { + return errors.New("tls: invalid signature by the server certificate: " + err.Error()) + } } else { sigType, sigHash, err = legacyTypeAndHashFromPublicKey(cert.PublicKey) if err != nil { return err } - } - if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA { - return errServerKeyExchange - } - - sigLen := int(sig[0])<<8 | int(sig[1]) - if sigLen+2 != len(sig) { - return errServerKeyExchange - } - sig = sig[2:] - - signed := hashForServerKeyExchange(sigType, sigHash, ka.version, clientHello.random, serverHello.random, serverECDHEParams) - if err := verifyHandshakeSignature(sigType, cert.PublicKey, sigHash, signed, sig); err != nil { - return errors.New("tls: invalid signature by the server certificate: " + err.Error()) + if (sigType == signaturePKCS1v15) != ka.isRSA { + return errServerKeyExchange + } + signed := hashForServerKeyExchange(sigType, clientHello.random, serverHello.random, serverECDHEParams) + if err := verifyLegacyHandshakeSignature(sigType, cert.PublicKey, sigHash, signed, sig); err != nil { + return errors.New("tls: invalid signature by the server certificate: " + err.Error()) + } } return nil } @@ -369,3 +366,29 @@ func (ka *ecdheKeyAgreement) generateClientKeyExchange(config *Config, clientHel return ka.preMasterSecret, ka.ckx, nil } + +// generateECDHEKey returns a PrivateKey that implements Diffie-Hellman +// according to RFC 8446, Section 4.2.8.2. +func generateECDHEKey(rand io.Reader, curveID CurveID) (*ecdh.PrivateKey, error) { + curve, ok := curveForCurveID(curveID) + if !ok { + return nil, errors.New("tls: internal error: unsupported curve") + } + + return curve.GenerateKey(rand) +} + +func curveForCurveID(id CurveID) (ecdh.Curve, bool) { + switch id { + case X25519: + return ecdh.X25519(), true + case CurveP256: + return ecdh.P256(), true + case CurveP384: + return ecdh.P384(), true + case CurveP521: + return ecdh.P521(), true + default: + return nil, false + } +} diff --git a/key_schedule.go b/key_schedule.go index acf4b96..b07a86f 100644 --- a/key_schedule.go +++ b/key_schedule.go @@ -5,7 +5,9 @@ package reality import ( + "crypto" "crypto/ecdh" + "crypto/fips140" "crypto/hmac" "crypto/mlkem" "errors" @@ -51,35 +53,222 @@ func (c *cipherSuiteTLS13) exportKeyingMaterial(s *tls13.MasterSecret, transcrip } type keySharePrivateKeys struct { - curveID CurveID - ecdhe *ecdh.PrivateKey - mlkem *mlkem.DecapsulationKey768 + ecdhe *ecdh.PrivateKey + mlkem crypto.Decapsulator } -const x25519PublicKeySize = 32 +// A keyExchange implements a TLS 1.3 KEM. +type keyExchange interface { + // keyShares generates one or two key shares. + // + // The first one will match the id, the second (if present) reuses the + // traditional component of the requested hybrid, as allowed by + // draft-ietf-tls-hybrid-design-09, Section 3.2. + keyShares(rand io.Reader) (*keySharePrivateKeys, []keyShare, error) -// generateECDHEKey returns a PrivateKey that implements Diffie-Hellman -// according to RFC 8446, Section 4.2.8.2. -func generateECDHEKey(rand io.Reader, curveID CurveID) (*ecdh.PrivateKey, error) { - curve, ok := curveForCurveID(curveID) - if !ok { - return nil, errors.New("tls: internal error: unsupported curve") - } + // serverSharedSecret computes the shared secret and the server's key share. + serverSharedSecret(rand io.Reader, clientKeyShare []byte) ([]byte, keyShare, error) - return curve.GenerateKey(rand) + // clientSharedSecret computes the shared secret given the server's key + // share and the keys generated by keyShares. + clientSharedSecret(priv *keySharePrivateKeys, serverKeyShare []byte) ([]byte, error) } -func curveForCurveID(id CurveID) (ecdh.Curve, bool) { +func keyExchangeForCurveID(id CurveID) (keyExchange, error) { + newMLKEMPrivateKey768 := func(b []byte) (crypto.Decapsulator, error) { + return mlkem.NewDecapsulationKey768(b) + } + newMLKEMPrivateKey1024 := func(b []byte) (crypto.Decapsulator, error) { + return mlkem.NewDecapsulationKey1024(b) + } + newMLKEMPublicKey768 := func(b []byte) (crypto.Encapsulator, error) { + return mlkem.NewEncapsulationKey768(b) + } + newMLKEMPublicKey1024 := func(b []byte) (crypto.Encapsulator, error) { + return mlkem.NewEncapsulationKey1024(b) + } switch id { case X25519: - return ecdh.X25519(), true + return &ecdhKeyExchange{id, ecdh.X25519()}, nil case CurveP256: - return ecdh.P256(), true + return &ecdhKeyExchange{id, ecdh.P256()}, nil case CurveP384: - return ecdh.P384(), true + return &ecdhKeyExchange{id, ecdh.P384()}, nil case CurveP521: - return ecdh.P521(), true + return &ecdhKeyExchange{id, ecdh.P521()}, nil + case X25519MLKEM768: + return &hybridKeyExchange{id, ecdhKeyExchange{X25519, ecdh.X25519()}, + 32, mlkem.EncapsulationKeySize768, mlkem.CiphertextSize768, + newMLKEMPrivateKey768, newMLKEMPublicKey768}, nil + case SecP256r1MLKEM768: + return &hybridKeyExchange{id, ecdhKeyExchange{CurveP256, ecdh.P256()}, + 65, mlkem.EncapsulationKeySize768, mlkem.CiphertextSize768, + newMLKEMPrivateKey768, newMLKEMPublicKey768}, nil + case SecP384r1MLKEM1024: + return &hybridKeyExchange{id, ecdhKeyExchange{CurveP384, ecdh.P384()}, + 97, mlkem.EncapsulationKeySize1024, mlkem.CiphertextSize1024, + newMLKEMPrivateKey1024, newMLKEMPublicKey1024}, nil default: - return nil, false + return nil, errors.New("tls: unsupported key exchange") + } +} + +type ecdhKeyExchange struct { + id CurveID + curve ecdh.Curve +} + +func (ke *ecdhKeyExchange) keyShares(rand io.Reader) (*keySharePrivateKeys, []keyShare, error) { + priv, err := ke.curve.GenerateKey(rand) + if err != nil { + return nil, nil, err + } + return &keySharePrivateKeys{ecdhe: priv}, []keyShare{{ke.id, priv.PublicKey().Bytes()}}, nil +} + +func (ke *ecdhKeyExchange) serverSharedSecret(rand io.Reader, clientKeyShare []byte) ([]byte, keyShare, error) { + key, err := ke.curve.GenerateKey(rand) + if err != nil { + return nil, keyShare{}, err + } + peerKey, err := ke.curve.NewPublicKey(clientKeyShare) + if err != nil { + return nil, keyShare{}, err + } + sharedKey, err := key.ECDH(peerKey) + if err != nil { + return nil, keyShare{}, err + } + return sharedKey, keyShare{ke.id, key.PublicKey().Bytes()}, nil +} + +func (ke *ecdhKeyExchange) clientSharedSecret(priv *keySharePrivateKeys, serverKeyShare []byte) ([]byte, error) { + peerKey, err := ke.curve.NewPublicKey(serverKeyShare) + if err != nil { + return nil, err + } + sharedKey, err := priv.ecdhe.ECDH(peerKey) + if err != nil { + return nil, err + } + return sharedKey, nil +} + +type hybridKeyExchange struct { + id CurveID + ecdh ecdhKeyExchange + + ecdhElementSize int + mlkemPublicKeySize int + mlkemCiphertextSize int + + newMLKEMPrivateKey func([]byte) (crypto.Decapsulator, error) + newMLKEMPublicKey func([]byte) (crypto.Encapsulator, error) +} + +func (ke *hybridKeyExchange) keyShares(rand io.Reader) (*keySharePrivateKeys, []keyShare, error) { + var ( + priv *keySharePrivateKeys + ecdhShares []keyShare + err error + ) + fips140.WithoutEnforcement(func() { // Hybrid of ML-KEM, which is Approved. + priv, ecdhShares, err = ke.ecdh.keyShares(rand) + }) + if err != nil { + return nil, nil, err + } + seed := make([]byte, mlkem.SeedSize) + if _, err := io.ReadFull(rand, seed); err != nil { + return nil, nil, err + } + priv.mlkem, err = ke.newMLKEMPrivateKey(seed) + if err != nil { + return nil, nil, err + } + var shareData []byte + // For X25519MLKEM768, the ML-KEM-768 encapsulation key comes first. + // For SecP256r1MLKEM768 and SecP384r1MLKEM1024, the ECDH share comes first. + // See draft-ietf-tls-ecdhe-mlkem-02, Section 4.1. + if ke.id == X25519MLKEM768 { + shareData = append(priv.mlkem.Encapsulator().Bytes(), ecdhShares[0].data...) + } else { + shareData = append(ecdhShares[0].data, priv.mlkem.Encapsulator().Bytes()...) + } + return priv, []keyShare{{ke.id, shareData}, ecdhShares[0]}, nil +} + +func (ke *hybridKeyExchange) serverSharedSecret(rand io.Reader, clientKeyShare []byte) ([]byte, keyShare, error) { + if len(clientKeyShare) != ke.ecdhElementSize+ke.mlkemPublicKeySize { + return nil, keyShare{}, errors.New("tls: invalid client key share length for hybrid key exchange") + } + var ecdhShareData, mlkemShareData []byte + if ke.id == X25519MLKEM768 { + mlkemShareData = clientKeyShare[:ke.mlkemPublicKeySize] + ecdhShareData = clientKeyShare[ke.mlkemPublicKeySize:] + } else { + ecdhShareData = clientKeyShare[:ke.ecdhElementSize] + mlkemShareData = clientKeyShare[ke.ecdhElementSize:] } -} \ No newline at end of file + var ( + ecdhSharedSecret []byte + ks keyShare + err error + ) + fips140.WithoutEnforcement(func() { // Hybrid of ML-KEM, which is Approved. + ecdhSharedSecret, ks, err = ke.ecdh.serverSharedSecret(rand, ecdhShareData) + }) + if err != nil { + return nil, keyShare{}, err + } + mlkemPeerKey, err := ke.newMLKEMPublicKey(mlkemShareData) + if err != nil { + return nil, keyShare{}, err + } + mlkemSharedSecret, mlkemKeyShare := mlkemPeerKey.Encapsulate() + var sharedKey []byte + if ke.id == X25519MLKEM768 { + sharedKey = append(mlkemSharedSecret, ecdhSharedSecret...) + ks.data = append(mlkemKeyShare, ks.data...) + } else { + sharedKey = append(ecdhSharedSecret, mlkemSharedSecret...) + ks.data = append(ks.data, mlkemKeyShare...) + } + ks.group = ke.id + return sharedKey, ks, nil +} + +func (ke *hybridKeyExchange) clientSharedSecret(priv *keySharePrivateKeys, serverKeyShare []byte) ([]byte, error) { + if len(serverKeyShare) != ke.ecdhElementSize+ke.mlkemCiphertextSize { + return nil, errors.New("tls: invalid server key share length for hybrid key exchange") + } + var ecdhShareData, mlkemShareData []byte + if ke.id == X25519MLKEM768 { + mlkemShareData = serverKeyShare[:ke.mlkemCiphertextSize] + ecdhShareData = serverKeyShare[ke.mlkemCiphertextSize:] + } else { + ecdhShareData = serverKeyShare[:ke.ecdhElementSize] + mlkemShareData = serverKeyShare[ke.ecdhElementSize:] + } + var ( + ecdhSharedSecret []byte + err error + ) + fips140.WithoutEnforcement(func() { // Hybrid of ML-KEM, which is Approved. + ecdhSharedSecret, err = ke.ecdh.clientSharedSecret(priv, ecdhShareData) + }) + if err != nil { + return nil, err + } + mlkemSharedSecret, err := priv.mlkem.Decapsulate(mlkemShareData) + if err != nil { + return nil, err + } + var sharedKey []byte + if ke.id == X25519MLKEM768 { + sharedKey = append(mlkemSharedSecret, ecdhSharedSecret...) + } else { + sharedKey = append(ecdhSharedSecret, mlkemSharedSecret...) + } + return sharedKey, nil +} diff --git a/prf.go b/prf.go index 290c4ca..f288642 100644 --- a/prf.go +++ b/prf.go @@ -222,22 +222,9 @@ func (h finishedHash) serverSum(masterSecret []byte) []byte { return h.prf(masterSecret, serverFinishedLabel, h.Sum(), finishedVerifyLength) } -// hashForClientCertificate returns the handshake messages so far, pre-hashed if -// necessary, suitable for signing by a TLS client certificate. -func (h finishedHash) hashForClientCertificate(sigType uint8, hashAlg crypto.Hash) []byte { - if (h.version >= VersionTLS12 || sigType == signatureEd25519) && h.buffer == nil { - panic("tls: handshake hash for a client certificate requested after discarding the handshake buffer") - } - - if sigType == signatureEd25519 { - return h.buffer - } - - if h.version >= VersionTLS12 { - hash := hashAlg.New() - hash.Write(h.buffer) - return hash.Sum(nil) - } +// hashForClientCertificate returns the handshake messages so far, pre-hashed, +// suitable for signing by a TLS 1.0 and 1.1 client certificate. +func (h finishedHash) hashForClientCertificate(sigType uint8) []byte { if sigType == signatureECDSA { return h.server.Sum(nil) diff --git a/quic.go b/quic.go index 83637e1..95a1995 100644 --- a/quic.go +++ b/quic.go @@ -8,6 +8,7 @@ import ( "context" "errors" "fmt" + "net" ) // QUICEncryptionLevel represents a QUIC encryption level used to transmit @@ -56,6 +57,9 @@ type QUICConfig struct { // stored in the client session cache. // The application should use [QUICConn.StoreSession] to store sessions. EnableSessionEvents bool + + // ClientHelloInfoConn is the net.Conn to use for the ClientHelloInfo.Conn field. + ClientHelloInfoConn net.Conn } // A QUICEventKind is a type of operation on a QUIC connection. @@ -117,6 +121,11 @@ const ( // The application may modify the [SessionState] before storing it. // This event only occurs on client connections. QUICStoreSession + + // QUICErrorEvent indicates that a fatal error has occurred. + // The handshake cannot proceed and the connection must be closed. + // QUICEvent.Err is set. + QUICErrorEvent ) // A QUICEvent is an event occurring on a QUIC connection. @@ -138,6 +147,10 @@ type QUICEvent struct { // Set for QUICResumeSession and QUICStoreSession. SessionState *SessionState + + // Set for QUICErrorEvent. + // The error will wrap AlertError. + Err error } type quicState struct { @@ -153,10 +166,11 @@ type quicState struct { started bool signalc chan struct{} // handshake data is available to be read blockedc chan struct{} // handshake is waiting for data, closed when done - cancelc <-chan struct{} // handshake has been canceled + ctx context.Context // handshake context cancel context.CancelFunc waitingForDrain bool + errorReturned bool // readbuf is shared between HandleData and the handshake goroutine. // HandshakeCryptoData passes ownership to the handshake goroutine by @@ -166,6 +180,7 @@ type quicState struct { transportParams []byte // to send to the peer enableSessionEvents bool + clientHelloInfoConn net.Conn } // QUICClient returns a new TLS client side connection using QUICTransport as the @@ -190,6 +205,7 @@ func newQUICConn(conn *Conn, config *QUICConfig) *QUICConn { signalc: make(chan struct{}), blockedc: make(chan struct{}), enableSessionEvents: config.EnableSessionEvents, + clientHelloInfoConn: config.ClientHelloInfoConn, } conn.quic.events = conn.quic.eventArr[:0] return &QUICConn{ @@ -222,7 +238,7 @@ func (q *QUICConn) NextEvent() QUICEvent { qs := q.conn.quic if last := qs.nextEvent - 1; last >= 0 && len(qs.events[last].Data) > 0 { // Write over some of the previous event's data, - // to catch callers erroniously retaining it. + // to catch callers erroneously retaining it. qs.events[last].Data[0] = 0 } if qs.nextEvent >= len(qs.events) && qs.waitingForDrain { @@ -230,6 +246,15 @@ func (q *QUICConn) NextEvent() QUICEvent { <-qs.signalc <-qs.blockedc } + if err := q.conn.handshakeErr; err != nil { + if qs.errorReturned { + return QUICEvent{Kind: QUICNoEvent} + } + qs.errorReturned = true + qs.events = nil + qs.nextEvent = 0 + return QUICEvent{Kind: QUICErrorEvent, Err: q.conn.handshakeErr} + } if qs.nextEvent >= len(qs.events) { qs.events = qs.events[:0] qs.nextEvent = 0 @@ -243,10 +268,11 @@ func (q *QUICConn) NextEvent() QUICEvent { // Close closes the connection and stops any in-progress handshake. func (q *QUICConn) Close() error { - if q.conn.quic.cancel == nil { + if q.conn.quic.ctx == nil { return nil // never started } q.conn.quic.cancel() + <-q.conn.quic.signalc for range q.conn.quic.blockedc { // Wait for the handshake goroutine to return. } @@ -303,6 +329,9 @@ type QUICSessionTicketOptions struct { // Currently, it can only be called once. func (q *QUICConn) SendSessionTicket(opts QUICSessionTicketOptions) error { c := q.conn + if c.config.SessionTicketsDisabled { + return nil + } if !c.isHandshakeComplete.Load() { return quicError(errors.New("tls: SendSessionTicket called before handshake completed")) } @@ -360,12 +389,11 @@ func quicError(err error) error { if err == nil { return nil } - var ae AlertError - if errors.As(err, &ae) { + if _, ok := errors.AsType[AlertError](err); ok { return err } - var a alert - if !errors.As(err, &a) { + a, ok := errors.AsType[alert](err) + if !ok { a = alertInternalError } // Return an error wrapping the original error and an AlertError. @@ -382,13 +410,22 @@ func (c *Conn) quicReadHandshakeBytes(n int) error { return nil } -func (c *Conn) quicSetReadSecret(level QUICEncryptionLevel, suite uint16, secret []byte) { +func (c *Conn) quicSetReadSecret(level QUICEncryptionLevel, suite uint16, secret []byte) error { + // Ensure that there are no buffered handshake messages before changing the + // read keys, since that can cause messages to be parsed that were encrypted + // using old keys which are no longer appropriate. + // TODO(roland): we should merge this check with the similar one in setReadTrafficSecret. + if c.hand.Len() != 0 { + c.sendAlert(alertUnexpectedMessage) + return errors.New("tls: handshake buffer not empty before setting read traffic secret") + } c.quic.events = append(c.quic.events, QUICEvent{ Kind: QUICSetReadSecret, Level: level, Suite: suite, Data: secret, }) + return nil } func (c *Conn) quicSetWriteSecret(level QUICEncryptionLevel, suite uint16, secret []byte) { @@ -482,20 +519,16 @@ func (c *Conn) quicWaitForSignal() error { // Send on blockedc to notify the QUICConn that the handshake is blocked. // Exported methods of QUICConn wait for the handshake to become blocked // before returning to the user. - select { - case c.quic.blockedc <- struct{}{}: - case <-c.quic.cancelc: - return c.sendAlertLocked(alertCloseNotify) - } + c.quic.blockedc <- struct{}{} // The QUICConn reads from signalc to notify us that the handshake may // be able to proceed. (The QUICConn reads, because we close signalc to // indicate that the handshake has completed.) - select { - case c.quic.signalc <- struct{}{}: - c.hand.Write(c.quic.readbuf) - c.quic.readbuf = nil - case <-c.quic.cancelc: + c.quic.signalc <- struct{}{} + if c.quic.ctx.Err() != nil { + // The connection has been canceled. return c.sendAlertLocked(alertCloseNotify) } + c.hand.Write(c.quic.readbuf) + c.quic.readbuf = nil return nil } \ No newline at end of file diff --git a/ticket.go b/ticket.go index d5aabb0..1154d80 100644 --- a/ticket.go +++ b/ticket.go @@ -81,7 +81,7 @@ type SessionState struct { version uint16 isClient bool cipherSuite uint16 - // createdAt is the generation time of the secret on the sever (which for + // createdAt is the generation time of the secret on the server (which for // TLS 1.0–1.2 might be earlier than the current session) and the time at // which the ticket was received on the client. createdAt uint64 // seconds since UNIX epoch diff --git a/tls.go b/tls.go index 4c8ef86..69cea8d 100644 --- a/tls.go +++ b/tls.go @@ -27,7 +27,6 @@ package reality // https://www.imperialviolet.org/2013/02/04/luckythirteen.html. import ( - "bytes" "context" "crypto" "crypto/aes" @@ -789,7 +788,7 @@ func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (Certificate, error) { if !ok { return fail(errors.New("tls: private key type does not match public key type")) } - if pub.N.Cmp(priv.N) != 0 { + if !priv.PublicKey.Equal(pub) { return fail(errors.New("tls: private key does not match public key")) } case *ecdsa.PublicKey: @@ -797,7 +796,7 @@ func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (Certificate, error) { if !ok { return fail(errors.New("tls: private key type does not match public key type")) } - if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 { + if !priv.PublicKey.Equal(pub) { return fail(errors.New("tls: private key does not match public key")) } case ed25519.PublicKey: @@ -805,7 +804,7 @@ func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (Certificate, error) { if !ok { return fail(errors.New("tls: private key type does not match public key type")) } - if !bytes.Equal(priv.Public().(ed25519.PublicKey), pub) { + if !priv.Public().(ed25519.PublicKey).Equal(pub) { return fail(errors.New("tls: private key does not match public key")) } default: