/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
/* vim: set ts=2 et sw=2 tw=80: */
/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this file,
 * You can obtain one at http://mozilla.org/MPL/2.0/. */

#include "tls_agent.h"

#include <algorithm>
#include <cstring>
#include <string>
#include <vector>

#include "pk11func.h"
#include "ssl.h"
#include "sslerr.h"
#include "sslproto.h"
#include "keyhi.h"

#define GTEST_HAS_RTTI 0
#include "gtest/gtest.h"

namespace nss_test {

// Printable names for the agent states, in the order
// INIT, CONNECTING, CONNECTED, ERROR.
const char* TlsAgent::states[] = {"INIT", "CONNECTING", "CONNECTED", "ERROR"};

// Builds an agent with the given name, role, transport mode and key exchange
// type.  All expectation flags start cleared, and the version range is seeded
// from the library default for the selected variant (stream or datagram).
TlsAgent::TlsAgent(const std::string& name, Role role, Mode mode,
                   SSLKEAType kea)
    : name_(name),
      mode_(mode),
      kea_(kea),
      pr_fd_(nullptr),
      adapter_(nullptr),
      ssl_fd_(nullptr),
      role_(role),
      state_(STATE_INIT),
      falsestart_enabled_(false),
      expected_version_(0),
      expected_cipher_suite_(0),
      expect_resumption_(false),
      can_falsestart_hook_called_(false),
      sni_hook_called_(false),
      auth_certificate_hook_called_(false),
      handshake_callback_called_(false),
      error_code_(0),
      send_ctr_(0),
      recv_ctr_(0),
      expected_read_error_(false) {
  // RequestClientAuth() sets this flag to true; give it a defined starting
  // value like every other flag.  It is assigned here rather than in the
  // initializer list so that this does not depend on the member declaration
  // order in the header (not visible from this file).
  expect_client_auth_ = false;

  memset(&info_, 0, sizeof(info_));
  memset(&csinfo_, 0, sizeof(csinfo_));

  SECStatus rv = SSL_VersionRangeGetDefault(
      mode_ == STREAM ? ssl_variant_stream : ssl_variant_datagram, &vrange_);
  EXPECT_EQ(SECSuccess, rv);
}

// Cancels any pending readable-event registration for the adapter and closes
// whichever file descriptors this agent still holds.
TlsAgent::~TlsAgent() {
  if (adapter_) {
    Poller::Instance()->Cancel(READABLE_EVENT, adapter_);
  }

  if (pr_fd_) {
    PR_Close(pr_fd_);
  }

  if (ssl_fd_) {
    PR_Close(ssl_fd_);
  }
}

// Imports the underlying descriptor into an SSL (or DTLS) socket and installs
// the certificate, version range and callbacks.  Safe to call repeatedly: it
// is a no-op once ssl_fd_ exists.
//
// Returns true when the socket is ready, false when a step failed (a gtest
// expectation failure is recorded in that case).
bool TlsAgent::EnsureTlsSetup() {
  // Don't set up twice
  if (ssl_fd_) return true;

  // NOTE(review): adapter_ is dereferenced unconditionally; presumably Init()
  // (defined outside this file) has created it by now -- confirm.
  if (adapter_->mode() == STREAM) {
    ssl_fd_ = SSL_ImportFD(nullptr, pr_fd_);
  } else {
    ssl_fd_ = DTLS_ImportFD(nullptr, pr_fd_);
  }

  EXPECT_NE(nullptr, ssl_fd_);
  if (!ssl_fd_) return false;
  // pr_fd_ is cleared after a successful import so the destructor closes
  // only ssl_fd_.
  pr_fd_ = nullptr;

  if (role_ == SERVER) {
    CERTCertificate* cert = PK11_FindCertFromNickname(name_.c_str(), nullptr);
    EXPECT_NE(nullptr, cert);
    if (!cert) return false;

    SECKEYPrivateKey* priv = PK11_FindKeyByAnyCert(cert, nullptr);
    EXPECT_NE(nullptr, priv);
    if (!priv) {
      CERT_DestroyCertificate(cert);
      return false;
    }

    SECStatus rv = SSL_ConfigSecureServer(ssl_fd_, cert, priv, kea_);
    // Our references are no longer needed whether or not the call succeeded,
    // so release them before looking at the result.
    SECKEY_DestroyPrivateKey(priv);
    CERT_DestroyCertificate(cert);
    EXPECT_EQ(SECSuccess, rv);
    if (rv != SECSuccess) return false;

    rv = SSL_SNISocketConfigHook(ssl_fd_, SniHook, this);
    EXPECT_EQ(SECSuccess, rv);  // don't abort, just fail
  } else {
    SECStatus rv = SSL_SetURL(ssl_fd_, "server");
    EXPECT_EQ(SECSuccess, rv);
    if (rv != SECSuccess) return false;
  }

  SECStatus rv = SSL_VersionRangeSet(ssl_fd_, &vrange_);
  EXPECT_EQ(SECSuccess, rv);
  if (rv != SECSuccess) return false;

  rv = SSL_AuthCertificateHook(ssl_fd_, AuthCertificateHook, this);
  EXPECT_EQ(SECSuccess, rv);
  if (rv != SECSuccess) return false;

  rv = SSL_HandshakeCallback(ssl_fd_, HandshakeCallback, this);
  EXPECT_EQ(SECSuccess, rv);
  if (rv != SECSuccess) return false;

  return true;
}

// Client only: registers GetClientAuthDataHook so that this agent supplies a
// certificate when asked for one.
void TlsAgent::SetupClientAuth() {
  EXPECT_TRUE(EnsureTlsSetup());
  ASSERT_EQ(CLIENT, role_);

  EXPECT_EQ(SECSuccess,
            SSL_GetClientAuthDataHook(ssl_fd_, GetClientAuthDataHook,
                                      static_cast<void*>(this)));
}

// Looks up the certificate stored under this agent's name together with its
// private key.
//
// On success returns true with *cert and *priv set.  On failure returns false
// and hands nothing to the caller: a certificate found without a matching key
// is destroyed and *cert is reset to nullptr.
bool TlsAgent::GetClientAuthCredentials(CERTCertificate** cert,
                                        SECKEYPrivateKey** priv) const {
  *cert = PK11_FindCertFromNickname(name_.c_str(), nullptr);
  EXPECT_NE(nullptr, *cert);
  if (!*cert) return false;

  *priv = PK11_FindKeyByAnyCert(*cert, nullptr);
  EXPECT_NE(nullptr, *priv);
  if (!*priv) {
    CERT_DestroyCertificate(*cert);
    *cert = nullptr;
    return false;
  }

  return true;
}

// Static trampoline registered with SSL_GetClientAuthDataHook; |self| is the
// TlsAgent passed at registration time.
SECStatus TlsAgent::GetClientAuthDataHook(void* self, PRFileDesc* fd,
                                          CERTDistNames* caNames,
                                          CERTCertificate** cert,
                                          SECKEYPrivateKey** privKey) {
  TlsAgent* agent = static_cast<TlsAgent*>(self);
  if (agent->GetClientAuthCredentials(cert, privKey)) {
    return SECSuccess;
  }
  return SECFailure;
}

// Server only: asks the peer for a certificate (making it mandatory when
// |requireAuth| is true) and installs ClientAuthenticated as the certificate
// authentication hook.
void TlsAgent::RequestClientAuth(bool requireAuth) {
  EXPECT_TRUE(EnsureTlsSetup());
  ASSERT_EQ(SERVER, role_);

  EXPECT_EQ(SECSuccess,
            SSL_OptionSet(ssl_fd_, SSL_REQUEST_CERTIFICATE, PR_TRUE));
  EXPECT_EQ(SECSuccess,
            SSL_OptionSet(ssl_fd_, SSL_REQUIRE_CERTIFICATE,
                          requireAuth ? PR_TRUE : PR_FALSE));

  EXPECT_EQ(SECSuccess,
            SSL_AuthCertificateHook(ssl_fd_, &TlsAgent::ClientAuthenticated,
                                    this));
  expect_client_auth_ = true;
}

// Resets the handshake state for this agent's role and moves the agent to
// STATE_CONNECTING.
void TlsAgent::StartConnect() {
  EXPECT_TRUE(EnsureTlsSetup());

  SECStatus rv;
  rv = SSL_ResetHandshake(ssl_fd_, role_ == SERVER ? PR_TRUE : PR_FALSE);
  EXPECT_EQ(SECSuccess, rv);
  SetState(STATE_CONNECTING);
}

// Disables every implemented cipher suite whose key exchange type is |kea|.
void TlsAgent::DisableCiphersByKeyExchange(SSLKEAType kea) {
  EXPECT_TRUE(EnsureTlsSetup());

  for (size_t i = 0; i < SSL_NumImplementedCiphers; ++i) {
    SSLCipherSuiteInfo csinfo;

    SECStatus rv = SSL_GetCipherSuiteInfo(SSL_ImplementedCiphers[i], &csinfo,
                                          sizeof(csinfo));
    ASSERT_EQ(SECSuccess, rv);

    if (csinfo.keaType == kea) {
      rv = SSL_CipherPrefSet(ssl_fd_, SSL_ImplementedCiphers[i], PR_FALSE);
      EXPECT_EQ(SECSuccess, rv);
    }
  }
}

// Turns the SSL_ENABLE_SESSION_TICKETS option on or off.
void TlsAgent::SetSessionTicketsEnabled(bool en) {
  EXPECT_TRUE(EnsureTlsSetup());

  SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_ENABLE_SESSION_TICKETS,
                               en ? PR_TRUE : PR_FALSE);
  EXPECT_EQ(SECSuccess, rv);
}

// Enables or disables the session cache.  The underlying option is
// SSL_NO_CACHE, so the flag is inverted.
void TlsAgent::SetSessionCacheEnabled(bool en) {
  EXPECT_TRUE(EnsureTlsSetup());

  SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_NO_CACHE,
                               en ? PR_FALSE : PR_TRUE);
  EXPECT_EQ(SECSuccess, rv);
}

// Records the version range and, if the SSL socket already exists, applies it
// immediately.  Otherwise EnsureTlsSetup() applies it later.
void TlsAgent::SetVersionRange(uint16_t minver, uint16_t maxver) {
  vrange_.min = minver;
  vrange_.max = maxver;

  if (ssl_fd_) {
    SECStatus rv = SSL_VersionRangeSet(ssl_fd_, &vrange_);
    EXPECT_EQ(SECSuccess, rv);
  }
}

// Sets the protocol version the handshake is expected to negotiate.
void TlsAgent::SetExpectedVersion(uint16_t version) {
  expected_version_ = version;
}

// When |err| is true, ReadBytes() records a read error in error_code_ instead
// of failing the test.
void TlsAgent::SetExpectedReadError(bool err) { expected_read_error_ = err; }

// Configures the signature algorithm preferences and then exercises the
// getter, including its failure modes, to confirm that what is configured is
// an ordered subset of |algorithms|.
void TlsAgent::SetSignatureAlgorithms(const SSLSignatureAndHashAlg* algorithms,
                                      size_t count) {
  EXPECT_TRUE(EnsureTlsSetup());
  EXPECT_LE(count, SSL_SignatureMaxCount());
  EXPECT_EQ(SECSuccess,
            SSL_SignaturePrefSet(ssl_fd_, algorithms,
                                 static_cast<unsigned int>(count)));
  EXPECT_EQ(SECFailure, SSL_SignaturePrefSet(ssl_fd_, algorithms, 0))
      << "setting no algorithms should fail and do nothing";

  std::vector<SSLSignatureAndHashAlg> configuredAlgorithms(count);
  // Zero-initialised so the checks below read a defined value even if every
  // getter call fails.
  unsigned int configuredCount = 0;
  EXPECT_EQ(SECFailure,
            SSL_SignaturePrefGet(ssl_fd_, nullptr, &configuredCount, 1))
      << "get algorithms, algorithms is nullptr";
  // data() rather than &v[0]: indexing an empty vector is undefined.
  EXPECT_EQ(SECFailure,
            SSL_SignaturePrefGet(ssl_fd_, configuredAlgorithms.data(),
                                 &configuredCount, 0))
      << "get algorithms, too little space";
  EXPECT_EQ(SECFailure,
            SSL_SignaturePrefGet(ssl_fd_, configuredAlgorithms.data(), nullptr,
                                 configuredAlgorithms.size()))
      << "get algorithms, algCountOut is nullptr";

  EXPECT_EQ(SECSuccess,
            SSL_SignaturePrefGet(ssl_fd_, configuredAlgorithms.data(),
                                 &configuredCount,
                                 configuredAlgorithms.size()));
  // SignaturePrefSet drops unsupported algorithms silently, so the number that
  // are configured might be fewer.
  EXPECT_LE(configuredCount, count);

  // Walk the requested list once, advancing through the configured list each
  // time the next configured entry is matched.  |i| never exceeds |j|, so the
  // configured vector (of size |count|) is never indexed out of range.
  unsigned int i = 0;
  for (size_t j = 0; j < count && i < configuredCount; ++j) {
    if (algorithms[j].hashAlg == configuredAlgorithms[i].hashAlg &&
        algorithms[j].sigAlg == configuredAlgorithms[i].sigAlg) {
      ++i;
    }
  }
  EXPECT_EQ(i, configuredCount) << "algorithms in use were all set";
}

// Expects the agent to be connected with the given key exchange type.
void TlsAgent::CheckKEAType(SSLKEAType type) const {
  EXPECT_EQ(STATE_CONNECTED, state_);
  EXPECT_EQ(type, csinfo_.keaType);
}

// Expects the agent to be connected with the given authentication type.
void TlsAgent::CheckAuthType(SSLAuthType type) const {
  EXPECT_EQ(STATE_CONNECTED, state_);
  EXPECT_EQ(type, csinfo_.authAlgorithm);
}

// Enables false start and installs the can-false-start callback.
void TlsAgent::EnableFalseStart() {
  EXPECT_TRUE(EnsureTlsSetup());

  falsestart_enabled_ = true;
  EXPECT_EQ(SECSuccess,
            SSL_SetCanFalseStartCallback(ssl_fd_, CanFalseStartCallback, this));
  EXPECT_EQ(SECSuccess,
            SSL_OptionSet(ssl_fd_, SSL_ENABLE_FALSE_START, PR_TRUE));
}

// Marks the next handshake as one that should resume a previous session.
void TlsAgent::ExpectResumption() { expect_resumption_ = true; }

// Enables ALPN and passes the |len| bytes at |val| to SSL_SetNextProtoNego.
void TlsAgent::EnableAlpn(const uint8_t* val, size_t len) {
  EXPECT_TRUE(EnsureTlsSetup());

  EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd_, SSL_ENABLE_ALPN, PR_TRUE));
  EXPECT_EQ(SECSuccess, SSL_SetNextProtoNego(ssl_fd_, val, len));
}

// Expects the negotiated next-protocol state and value to match.
void TlsAgent::CheckAlpn(SSLNextProtoState expected_state,
                         const std::string& expected) const {
  SSLNextProtoState state;
  char chosen[10];
  // Zero so that a failed query yields an empty string below rather than a
  // read of an indeterminate length.
  unsigned int chosen_len = 0;
  SECStatus rv = SSL_GetNextProto(ssl_fd_, &state,
                                  reinterpret_cast<unsigned char*>(chosen),
                                  &chosen_len, sizeof(chosen));
  EXPECT_EQ(SECSuccess, rv);
  EXPECT_EQ(expected_state, state);
  EXPECT_EQ(expected, std::string(chosen, chosen_len));
}

// Offers the two AES128-CM/HMAC-SHA1 SRTP protection profiles.
void TlsAgent::EnableSrtp() {
  EXPECT_TRUE(EnsureTlsSetup());

  const uint16_t ciphers[] = {SRTP_AES128_CM_HMAC_SHA1_80,
                              SRTP_AES128_CM_HMAC_SHA1_32};
  EXPECT_EQ(SECSuccess,
            SSL_SetSRTPCiphers(ssl_fd_, ciphers, PR_ARRAY_SIZE(ciphers)));
}

// Expects SRTP_AES128_CM_HMAC_SHA1_80 to have been selected.
void TlsAgent::CheckSrtp() const {
  // Zero-initialised so the comparison is well defined if the query fails.
  uint16_t actual = 0;
  EXPECT_EQ(SECSuccess, SSL_GetSRTPCipher(ssl_fd_, &actual));
  EXPECT_EQ(SRTP_AES128_CM_HMAC_SHA1_80, actual);
}

// Expects the agent to be in the error state with the given error code.
void TlsAgent::CheckErrorCode(int32_t expected) const {
  EXPECT_EQ(STATE_ERROR, state_);
  EXPECT_EQ(expected, error_code_);
}

// Checks the preliminary channel info against the expected version and cipher
// suite, adopting the reported values when no expectation was set.
void TlsAgent::CheckPreliminaryInfo() {
  SSLPreliminaryChannelInfo info;
  EXPECT_EQ(SECSuccess,
            SSL_GetPreliminaryChannelInfo(ssl_fd_, &info, sizeof(info)));
  EXPECT_TRUE(info.valuesSet & ssl_preinfo_version);
  EXPECT_TRUE(info.valuesSet & ssl_preinfo_cipher_suite);

  // A version of 0 is invalid and indicates no expectation.  This value is
  // initialized to 0 so that tests that don't explicitly set an expected
  // version can negotiate a version.
  if (!expected_version_) {
    expected_version_ = info.protocolVersion;
  }
  EXPECT_EQ(expected_version_, info.protocolVersion);

  // As with the version; 0 is the null cipher suite (and also invalid).
  if (!expected_cipher_suite_) {
    expected_cipher_suite_ = info.cipherSuite;
  }
  EXPECT_EQ(expected_cipher_suite_, info.cipherSuite);
}

// Check that all the expected callbacks have been called.
void TlsAgent::CheckCallbacks() const {
  // If false start happens, the handshake is reported as being complete at the
  // point that false start happens.
  if (expect_resumption_ || !falsestart_enabled_) {
    EXPECT_TRUE(handshake_callback_called_);
  }

  // These callbacks shouldn't fire if we are resuming.
  if (role_ == SERVER) {
    EXPECT_EQ(!expect_resumption_, sni_hook_called_);
  } else {
    EXPECT_EQ(!expect_resumption_, auth_certificate_hook_called_);
    // Note that this isn't unconditionally called, even with false start on.
    // But the callback is only skipped if a cipher that is ridiculously weak
    // (80 bits) is chosen.  Don't test that: plan to remove bad ciphers.
    EXPECT_EQ(falsestart_enabled_ && !expect_resumption_,
              can_falsestart_hook_called_);
  }
}

// Called once the handshake completes: verifies the callbacks, captures the
// channel and cipher suite info, and moves the agent to STATE_CONNECTED.
void TlsAgent::Connected() {
  LOG("Handshake success");
  CheckCallbacks();

  SECStatus rv = SSL_GetChannelInfo(ssl_fd_, &info_, sizeof(info_));
  EXPECT_EQ(SECSuccess, rv);

  // Preliminary values are exposed through callbacks during the handshake.
  // If either expected values were set or the callbacks were called, check
  // that the final values are correct.
  EXPECT_EQ(expected_version_, info_.protocolVersion);
  EXPECT_EQ(expected_cipher_suite_, info_.cipherSuite);

  rv = SSL_GetCipherSuiteInfo(info_.cipherSuite, &csinfo_, sizeof(csinfo_));
  EXPECT_EQ(SECSuccess, rv);

  SetState(STATE_CONNECTED);
}

// Turns on the extended master secret option.
void TlsAgent::EnableExtendedMasterSecret() {
  ASSERT_TRUE(EnsureTlsSetup());

  SECStatus rv =
      SSL_OptionSet(ssl_fd_, SSL_ENABLE_EXTENDED_MASTER_SECRET, PR_TRUE);
  ASSERT_EQ(SECSuccess, rv);
}

// Asserts whether the captured channel info reports extended master secret
// use.
void TlsAgent::CheckExtendedMasterSecret(bool expected) {
  ASSERT_EQ(expected, info_.extendedMasterSecretUsed != PR_FALSE)
      << "unexpected extended master secret state for " << name_;
}

// Turns off the SSL_ROLLBACK_DETECTION option.
void TlsAgent::DisableRollbackDetection() {
  ASSERT_TRUE(EnsureTlsSetup());

  SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_ROLLBACK_DETECTION, PR_FALSE);
  ASSERT_EQ(SECSuccess, rv);
}

// Drives the handshake one step.  On success the agent becomes connected; if
// the call would block the agent waits for more input; any other error moves
// the agent to STATE_ERROR with error_code_ set.
void TlsAgent::Handshake() {
  SECStatus rv = SSL_ForceHandshake(ssl_fd_);
  if (rv == SECSuccess) {
    Connected();

    Poller::Instance()->Wait(READABLE_EVENT, adapter_, this,
                             &TlsAgent::ReadableCallback);
    return;
  }

  int32_t err = PR_GetError();
  if (err == PR_WOULD_BLOCK_ERROR) {
    LOG("Would have blocked");
    // TODO(ekr@rtfm.com): set DTLS timeouts
    Poller::Instance()->Wait(READABLE_EVENT, adapter_, this,
                             &TlsAgent::ReadableCallback);
    return;
  }

  // TODO(ekr@rtfm.com): needs special case for DTLS
  // (for SSL_ERROR_RX_MALFORMED_HANDSHAKE, which is currently handled like
  // any other failure).
  if (IS_SSL_ERROR(err)) {
    LOG("Handshake failed with SSL error " << err - SSL_ERROR_BASE);
  } else {
    LOG("Handshake failed with error " << err);
  }
  error_code_ = err;
  SetState(STATE_ERROR);
}

// Moves a connected agent back to STATE_CONNECTING ahead of a renegotiation.
void TlsAgent::PrepareForRenegotiate() {
  EXPECT_EQ(STATE_CONNECTED, state_);

  SetState(STATE_CONNECTING);
}

// Starts a renegotiation from this agent.
void TlsAgent::StartRenegotiate() {
  PrepareForRenegotiate();

  SECStatus rv = SSL_ReHandshake(ssl_fd_, PR_TRUE);
  EXPECT_EQ(SECSuccess, rv);
}

// Writes |bytes| bytes of counter-patterned data in chunks of at most
// |blocksize| bytes.  Each byte is the low 8 bits of send_ctr_, which is
// incremented per byte, so the receiver can verify ordering.
void TlsAgent::SendData(size_t bytes, size_t blocksize) {
  uint8_t block[4096];

  ASSERT_LT(blocksize, sizeof(block));

  while (bytes) {
    size_t tosend = std::min(blocksize, bytes);
    // |bytes| is non-zero here, so a zero chunk means blocksize == 0, which
    // would otherwise spin forever.
    ASSERT_NE(static_cast<size_t>(0), tosend) << "blocksize must be non-zero";

    for (size_t i = 0; i < tosend; ++i) {
      block[i] = 0xff & send_ctr_;
      ++send_ctr_;
    }

    LOG("Writing " << tosend << " bytes");
    int32_t rv = PR_Write(ssl_fd_, block, tosend);
    ASSERT_EQ(tosend, static_cast<size_t>(rv));

    bytes -= tosend;
  }
}

// Reads whatever application data is available and verifies it against the
// running receive counter.  A read error other than would-block is stored in
// error_code_ when expected, and fails the test otherwise.  The agent then
// waits for the next readable event.
void TlsAgent::ReadBytes() {
  uint8_t block[1024];

  LOG("Reading application data from socket");

  int32_t rv = PR_Read(ssl_fd_, block, sizeof(block));
  if (rv >= 0) {
    size_t count = static_cast<size_t>(rv);
    LOG("Read " << count << " bytes");

    for (size_t i = 0; i < count; ++i) {
      ASSERT_EQ(recv_ctr_ & 0xff, block[i]);
      recv_ctr_++;
    }
  } else {
    // Only look at the error code after a failed read; after a successful
    // call PR_GetError() may still report an earlier, unrelated error.
    int32_t err = PR_GetError();
    if (err != PR_WOULD_BLOCK_ERROR) {
      if (expected_read_error_) {
        error_code_ = err;
      } else {
        // rv is negative here, so this always fails and reports the
        // unexpected read error.
        ASSERT_LE(0, rv) << "unexpected read error " << err;
      }
    }
  }

  Poller::Instance()->Wait(READABLE_EVENT, adapter_, this,
                           &TlsAgent::ReadableCallback);
}

// Restarts the send-side byte counter.
void TlsAgent::ResetSentBytes() { send_ctr_ = 0; }

// Enables the session-ID cache and/or session tickets according to the
// RESUME_SESSIONID and RESUME_TICKET bits of |mode|.
void TlsAgent::ConfigureSessionCache(SessionResumptionMode mode) {
  EXPECT_TRUE(EnsureTlsSetup());

  SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_NO_CACHE,
                               (mode & RESUME_SESSIONID) ? PR_FALSE : PR_TRUE);
  EXPECT_EQ(SECSuccess, rv);

  rv = SSL_OptionSet(ssl_fd_, SSL_ENABLE_SESSION_TICKETS,
                     (mode & RESUME_TICKET) ? PR_TRUE : PR_FALSE);
  EXPECT_EQ(SECSuccess, rv);
}

// Role names used to parameterize tests over both client and server.
static const std::string kTlsRolesAllArr[] = {"CLIENT", "SERVER"};
::testing::internal::ParamGenerator<std::string>
    TlsAgentTestBase::kTlsRolesAll = ::testing::ValuesIn(kTlsRolesAllArr);

// Creates the agent under test, connects its adapter to a dummy peer socket
// and starts the handshake.
void TlsAgentTestBase::Init() {
  agent_ = new TlsAgent(role_ == TlsAgent::CLIENT ? "client" : "server", role_,
                        mode_, kea_);
  agent_->Init();
  fd_ = DummyPrSocket::CreateFD("dummy", mode_);
  agent_->adapter()->SetPeer(DummyPrSocket::GetAdapter(fd_));
  agent_->StartConnect();
}

// Lazily initializes the agent on first use.
void TlsAgentTestBase::EnsureInit() {
  if (!agent_) {
    Init();
  }
}

// Feeds |buffer| to the agent as a received packet, runs one handshake step
// and asserts the resulting state (and error code, when the expected state is
// STATE_ERROR).
void TlsAgentTestBase::ProcessMessage(const DataBuffer& buffer,
                                      TlsAgent::State expected_state,
                                      int32_t error_code) {
  EnsureInit();
  agent_->adapter()->PacketReceived(buffer);
  agent_->Handshake();

  ASSERT_EQ(expected_state, agent_->state());

  if (expected_state == TlsAgent::STATE_ERROR) {
    ASSERT_EQ(error_code, agent_->error_code());
  }
}

}  // namespace nss_test