mirror of
https://github.com/classilla/tenfourfox.git
synced 2024-10-10 13:23:42 +00:00
573 lines
17 KiB
C++
573 lines
17 KiB
C++
/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
|
|
/* vim: set ts=2 et sw=2 tw=80: */
|
|
/* This Source Code Form is subject to the terms of the Mozilla Public
|
|
* License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
|
* You can obtain one at http://mozilla.org/MPL/2.0/. */
|
|
|
|
#include "tls_agent.h"
|
|
|
|
#include "pk11func.h"
|
|
#include "ssl.h"
|
|
#include "sslerr.h"
|
|
#include "sslproto.h"
|
|
#include "keyhi.h"
|
|
|
|
#define GTEST_HAS_RTTI 0
|
|
#include "gtest/gtest.h"
|
|
|
|
namespace nss_test {
|
|
|
|
|
|
// Printable names for the agent states, used when logging. Presumably indexed
// by State value (STATE_INIT, STATE_CONNECTING, ...) -- keep in sync with the
// enum declared in tls_agent.h.
const char* TlsAgent::states[] = {
    "INIT",
    "CONNECTING",
    "CONNECTED",
    "ERROR"
};
|
|
|
|
// Creates an agent with the given role, transport mode and key-exchange type.
// |name| is also the certificate nickname looked up by EnsureTlsSetup (server)
// and GetClientAuthCredentials (client). No NSS socket is created here; that
// happens lazily in EnsureTlsSetup. The version range starts at the library
// default for the variant implied by |mode| (STREAM => stream, else datagram).
TlsAgent::TlsAgent(const std::string& name, Role role, Mode mode, SSLKEAType kea)
  : name_(name),
    mode_(mode),
    kea_(kea),
    pr_fd_(nullptr),
    adapter_(nullptr),
    ssl_fd_(nullptr),
    role_(role),
    state_(STATE_INIT),
    falsestart_enabled_(false),
    expected_version_(0),
    expected_cipher_suite_(0),
    expect_resumption_(false),
    can_falsestart_hook_called_(false),
    sni_hook_called_(false),
    auth_certificate_hook_called_(false),
    handshake_callback_called_(false),
    error_code_(0),
    send_ctr_(0),
    recv_ctr_(0),
    expected_read_error_(false) {
  // expect_client_auth_ was previously never initialized; only
  // RequestClientAuth() sets it (to true). It is assigned here rather than
  // added to the initializer list because its declaration order in the header
  // is not visible from this file, and a misplaced entry would trip -Wreorder.
  expect_client_auth_ = false;

  memset(&info_, 0, sizeof(info_));
  memset(&csinfo_, 0, sizeof(csinfo_));
  SECStatus rv = SSL_VersionRangeGetDefault(
      mode_ == STREAM ? ssl_variant_stream : ssl_variant_datagram, &vrange_);
  EXPECT_EQ(SECSuccess, rv);
}
|
|
|
|
// Tears the agent down: stop waiting for readability on the adapter first,
// then close whichever NSPR descriptors the agent still holds. EnsureTlsSetup
// clears pr_fd_ once it has been imported into ssl_fd_, so the imported
// descriptor is not closed twice.
TlsAgent::~TlsAgent() {
  if (adapter_ != nullptr) {
    Poller::Instance()->Cancel(READABLE_EVENT, adapter_);
  }
  if (pr_fd_ != nullptr) {
    PR_Close(pr_fd_);
  }
  if (ssl_fd_ != nullptr) {
    PR_Close(ssl_fd_);
  }
}
|
|
|
|
bool TlsAgent::EnsureTlsSetup() {
|
|
// Don't set up twice
|
|
if (ssl_fd_) return true;
|
|
|
|
if (adapter_->mode() == STREAM) {
|
|
ssl_fd_ = SSL_ImportFD(nullptr, pr_fd_);
|
|
} else {
|
|
ssl_fd_ = DTLS_ImportFD(nullptr, pr_fd_);
|
|
}
|
|
|
|
EXPECT_NE(nullptr, ssl_fd_);
|
|
if (!ssl_fd_) return false;
|
|
pr_fd_ = nullptr;
|
|
|
|
if (role_ == SERVER) {
|
|
CERTCertificate* cert = PK11_FindCertFromNickname(name_.c_str(), nullptr);
|
|
EXPECT_NE(nullptr, cert);
|
|
if (!cert) return false;
|
|
|
|
SECKEYPrivateKey* priv = PK11_FindKeyByAnyCert(cert, nullptr);
|
|
EXPECT_NE(nullptr, priv);
|
|
if (!priv) return false; // Leak cert.
|
|
|
|
SECStatus rv = SSL_ConfigSecureServer(ssl_fd_, cert, priv, kea_);
|
|
EXPECT_EQ(SECSuccess, rv);
|
|
if (rv != SECSuccess) return false; // Leak cert and key.
|
|
|
|
SECKEY_DestroyPrivateKey(priv);
|
|
CERT_DestroyCertificate(cert);
|
|
|
|
rv = SSL_SNISocketConfigHook(ssl_fd_, SniHook, this);
|
|
EXPECT_EQ(SECSuccess, rv); // don't abort, just fail
|
|
} else {
|
|
SECStatus rv = SSL_SetURL(ssl_fd_, "server");
|
|
EXPECT_EQ(SECSuccess, rv);
|
|
if (rv != SECSuccess) return false;
|
|
}
|
|
|
|
SECStatus rv = SSL_VersionRangeSet(ssl_fd_, &vrange_);
|
|
EXPECT_EQ(SECSuccess, rv);
|
|
if (rv != SECSuccess) return false;
|
|
|
|
rv = SSL_AuthCertificateHook(ssl_fd_, AuthCertificateHook, this);
|
|
EXPECT_EQ(SECSuccess, rv);
|
|
if (rv != SECSuccess) return false;
|
|
|
|
rv = SSL_HandshakeCallback(ssl_fd_, HandshakeCallback, this);
|
|
EXPECT_EQ(SECSuccess, rv);
|
|
if (rv != SECSuccess) return false;
|
|
|
|
return true;
|
|
}
|
|
|
|
void TlsAgent::SetupClientAuth() {
|
|
EXPECT_TRUE(EnsureTlsSetup());
|
|
ASSERT_EQ(CLIENT, role_);
|
|
|
|
EXPECT_EQ(SECSuccess,
|
|
SSL_GetClientAuthDataHook(ssl_fd_, GetClientAuthDataHook,
|
|
reinterpret_cast<void*>(this)));
|
|
}
|
|
|
|
// Looks up the client-auth certificate whose nickname is name_, and its
// private key. On success returns true with *cert and *priv set; ownership
// passes to the caller. On failure records a gtest failure and returns false
// with *cert null, so no certificate reference is leaked.
bool TlsAgent::GetClientAuthCredentials(CERTCertificate **cert,
                                        SECKEYPrivateKey **priv) const {
  *cert = PK11_FindCertFromNickname(name_.c_str(), nullptr);
  EXPECT_NE(nullptr, *cert);
  if (!*cert) return false;

  *priv = PK11_FindKeyByAnyCert(*cert, nullptr);
  EXPECT_NE(nullptr, *priv);
  if (!*priv) {
    // Release the certificate instead of leaking it, and don't hand back a
    // certificate without its key.
    CERT_DestroyCertificate(*cert);
    *cert = nullptr;
    return false;
  }

  return true;
}
|
|
|
|
// NSS client-auth data callback; |self| is the TlsAgent passed by
// SetupClientAuth. Fills |cert| and |privKey| from the agent's credentials
// (|fd| and |caNames| are not consulted). Returns SECSuccess when credentials
// were found, SECFailure otherwise.
SECStatus TlsAgent::GetClientAuthDataHook(void* self, PRFileDesc* fd,
                                          CERTDistNames* caNames,
                                          CERTCertificate** cert,
                                          SECKEYPrivateKey** privKey) {
  const TlsAgent* const agent = reinterpret_cast<TlsAgent*>(self);
  const bool found = agent->GetClientAuthCredentials(cert, privKey);
  return found ? SECSuccess : SECFailure;
}
|
|
|
|
|
|
void TlsAgent::RequestClientAuth(bool requireAuth) {
|
|
EXPECT_TRUE(EnsureTlsSetup());
|
|
ASSERT_EQ(SERVER, role_);
|
|
|
|
EXPECT_EQ(SECSuccess,
|
|
SSL_OptionSet(ssl_fd_, SSL_REQUEST_CERTIFICATE, PR_TRUE));
|
|
EXPECT_EQ(SECSuccess,
|
|
SSL_OptionSet(ssl_fd_, SSL_REQUIRE_CERTIFICATE,
|
|
requireAuth ? PR_TRUE : PR_FALSE));
|
|
|
|
EXPECT_EQ(SECSuccess,
|
|
SSL_AuthCertificateHook(ssl_fd_, &TlsAgent::ClientAuthenticated,
|
|
this));
|
|
expect_client_auth_ = true;
|
|
}
|
|
|
|
void TlsAgent::StartConnect() {
|
|
EXPECT_TRUE(EnsureTlsSetup());
|
|
|
|
SECStatus rv;
|
|
rv = SSL_ResetHandshake(ssl_fd_, role_ == SERVER ? PR_TRUE : PR_FALSE);
|
|
EXPECT_EQ(SECSuccess, rv);
|
|
SetState(STATE_CONNECTING);
|
|
}
|
|
|
|
// Disables, on this socket, every implemented cipher suite whose key
// exchange type is |kea|.
void TlsAgent::DisableCiphersByKeyExchange(SSLKEAType kea) {
  EXPECT_TRUE(EnsureTlsSetup());

  for (size_t idx = 0; idx < SSL_NumImplementedCiphers; ++idx) {
    const PRUint16 suite = SSL_ImplementedCiphers[idx];
    SSLCipherSuiteInfo suite_info;
    SECStatus rv =
        SSL_GetCipherSuiteInfo(suite, &suite_info, sizeof(suite_info));
    ASSERT_EQ(SECSuccess, rv);

    if (suite_info.keaType != kea) {
      continue;
    }
    rv = SSL_CipherPrefSet(ssl_fd_, suite, PR_FALSE);
    EXPECT_EQ(SECSuccess, rv);
  }
}
|
|
|
|
void TlsAgent::SetSessionTicketsEnabled(bool en) {
|
|
EXPECT_TRUE(EnsureTlsSetup());
|
|
|
|
SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_ENABLE_SESSION_TICKETS,
|
|
en ? PR_TRUE : PR_FALSE);
|
|
EXPECT_EQ(SECSuccess, rv);
|
|
}
|
|
|
|
void TlsAgent::SetSessionCacheEnabled(bool en) {
|
|
EXPECT_TRUE(EnsureTlsSetup());
|
|
|
|
SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_NO_CACHE,
|
|
en ? PR_FALSE : PR_TRUE);
|
|
EXPECT_EQ(SECSuccess, rv);
|
|
}
|
|
|
|
void TlsAgent::SetVersionRange(uint16_t minver, uint16_t maxver) {
|
|
vrange_.min = minver;
|
|
vrange_.max = maxver;
|
|
|
|
if (ssl_fd_) {
|
|
SECStatus rv = SSL_VersionRangeSet(ssl_fd_, &vrange_);
|
|
EXPECT_EQ(SECSuccess, rv);
|
|
}
|
|
}
|
|
|
|
void TlsAgent::SetExpectedVersion(uint16_t version) {
|
|
expected_version_ = version;
|
|
}
|
|
|
|
void TlsAgent::SetExpectedReadError(bool err) {
|
|
expected_read_error_ = err;
|
|
}
|
|
|
|
void TlsAgent::SetSignatureAlgorithms(const SSLSignatureAndHashAlg* algorithms,
|
|
size_t count) {
|
|
EXPECT_TRUE(EnsureTlsSetup());
|
|
EXPECT_LE(count, SSL_SignatureMaxCount());
|
|
EXPECT_EQ(SECSuccess, SSL_SignaturePrefSet(ssl_fd_, algorithms,
|
|
static_cast<unsigned int>(count)));
|
|
EXPECT_EQ(SECFailure, SSL_SignaturePrefSet(ssl_fd_, algorithms, 0))
|
|
<< "setting no algorithms should fail and do nothing";
|
|
|
|
std::vector<SSLSignatureAndHashAlg> configuredAlgorithms(count);
|
|
unsigned int configuredCount;
|
|
EXPECT_EQ(SECFailure,
|
|
SSL_SignaturePrefGet(ssl_fd_, nullptr, &configuredCount, 1))
|
|
<< "get algorithms, algorithms is nullptr";
|
|
EXPECT_EQ(SECFailure,
|
|
SSL_SignaturePrefGet(ssl_fd_, &configuredAlgorithms[0],
|
|
&configuredCount, 0))
|
|
<< "get algorithms, too little space";
|
|
EXPECT_EQ(SECFailure,
|
|
SSL_SignaturePrefGet(ssl_fd_, &configuredAlgorithms[0],
|
|
nullptr, configuredAlgorithms.size()))
|
|
<< "get algorithms, algCountOut is nullptr";
|
|
|
|
EXPECT_EQ(SECSuccess,
|
|
SSL_SignaturePrefGet(ssl_fd_, &configuredAlgorithms[0],
|
|
&configuredCount,
|
|
configuredAlgorithms.size()));
|
|
// SignaturePrefSet drops unsupported algorithms silently, so the number that
|
|
// are configured might be fewer.
|
|
EXPECT_LE(configuredCount, count);
|
|
unsigned int i = 0;
|
|
for (unsigned int j = 0; j < count && i < configuredCount; ++j) {
|
|
if (i < configuredCount &&
|
|
algorithms[j].hashAlg == configuredAlgorithms[i].hashAlg &&
|
|
algorithms[j].sigAlg == configuredAlgorithms[i].sigAlg) {
|
|
++i;
|
|
}
|
|
}
|
|
EXPECT_EQ(i, configuredCount) << "algorithms in use were all set";
|
|
}
|
|
|
|
// Non-fatally checks that the agent is connected and that the negotiated
// cipher suite (csinfo_) uses key exchange |type|.
void TlsAgent::CheckKEAType(SSLKEAType type) const {
  EXPECT_EQ(STATE_CONNECTED, state_);
  const SSLKEAType negotiated = csinfo_.keaType;
  EXPECT_EQ(type, negotiated);
}
|
|
|
|
// Non-fatally checks that the agent is connected and that the negotiated
// cipher suite (csinfo_) uses authentication algorithm |type|.
void TlsAgent::CheckAuthType(SSLAuthType type) const {
  EXPECT_EQ(STATE_CONNECTED, state_);
  const SSLAuthType negotiated = csinfo_.authAlgorithm;
  EXPECT_EQ(type, negotiated);
}
|
|
|
|
void TlsAgent::EnableFalseStart() {
|
|
EXPECT_TRUE(EnsureTlsSetup());
|
|
|
|
falsestart_enabled_ = true;
|
|
EXPECT_EQ(SECSuccess,
|
|
SSL_SetCanFalseStartCallback(ssl_fd_, CanFalseStartCallback, this));
|
|
EXPECT_EQ(SECSuccess,
|
|
SSL_OptionSet(ssl_fd_, SSL_ENABLE_FALSE_START, PR_TRUE));
|
|
}
|
|
|
|
void TlsAgent::ExpectResumption() {
|
|
expect_resumption_ = true;
|
|
}
|
|
|
|
void TlsAgent::EnableAlpn(const uint8_t* val, size_t len) {
|
|
EXPECT_TRUE(EnsureTlsSetup());
|
|
|
|
EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd_, SSL_ENABLE_ALPN, PR_TRUE));
|
|
EXPECT_EQ(SECSuccess, SSL_SetNextProtoNego(ssl_fd_, val, len));
|
|
}
|
|
|
|
// Checks the next-protocol negotiation state and the selected protocol bytes
// reported by SSL_GetNextProto against |expected_state| and |expected|.
void TlsAgent::CheckAlpn(SSLNextProtoState expected_state,
                         const std::string& expected) const {
  SSLNextProtoState state;
  // A protocol name may be up to 255 bytes (RFC 7301). A smaller buffer makes
  // SSL_GetNextProto fail for longer names.
  char chosen[255];
  unsigned int chosen_len = 0;
  SECStatus rv = SSL_GetNextProto(ssl_fd_, &state,
                                  reinterpret_cast<unsigned char*>(chosen),
                                  &chosen_len, sizeof(chosen));
  // On failure |state| and |chosen| are not meaningful, so stop rather than
  // compare indeterminate values.
  ASSERT_EQ(SECSuccess, rv);
  EXPECT_EQ(expected_state, state);
  EXPECT_EQ(expected, std::string(chosen, chosen_len));
}
|
|
|
|
void TlsAgent::EnableSrtp() {
|
|
EXPECT_TRUE(EnsureTlsSetup());
|
|
const uint16_t ciphers[] = {
|
|
SRTP_AES128_CM_HMAC_SHA1_80, SRTP_AES128_CM_HMAC_SHA1_32
|
|
};
|
|
EXPECT_EQ(SECSuccess, SSL_SetSRTPCiphers(ssl_fd_, ciphers,
|
|
PR_ARRAY_SIZE(ciphers)));
|
|
}
|
|
|
|
// Checks that SRTP_AES128_CM_HMAC_SHA1_80 was the negotiated SRTP profile.
void TlsAgent::CheckSrtp() const {
  // Initialized so that a failed query compares a defined value instead of an
  // indeterminate one.
  uint16_t actual = 0;
  EXPECT_EQ(SECSuccess, SSL_GetSRTPCipher(ssl_fd_, &actual));
  EXPECT_EQ(SRTP_AES128_CM_HMAC_SHA1_80, actual);
}
|
|
|
|
void TlsAgent::CheckErrorCode(int32_t expected) const {
|
|
EXPECT_EQ(STATE_ERROR, state_);
|
|
EXPECT_EQ(expected, error_code_);
|
|
}
|
|
|
|
// Checks the preliminary channel info available during the handshake. The
// version and cipher suite must both be reported, and must match
// expected_version_ / expected_cipher_suite_. An expectation of 0 is replaced
// by the observed value, so Connected() later checks that the final values
// agree with it.
void TlsAgent::CheckPreliminaryInfo() {
  SSLPreliminaryChannelInfo info;
  // Zeroed so that a failed query leaves defined (empty) values rather than
  // stack garbage behind.
  memset(&info, 0, sizeof(info));
  EXPECT_EQ(SECSuccess,
            SSL_GetPreliminaryChannelInfo(ssl_fd_, &info, sizeof(info)));
  EXPECT_TRUE(info.valuesSet & ssl_preinfo_version);
  EXPECT_TRUE(info.valuesSet & ssl_preinfo_cipher_suite);

  // A version of 0 is invalid and indicates no expectation. This value is
  // initialized to 0 so that tests that don't explicitly set an expected
  // version can negotiate a version.
  if (!expected_version_) {
    expected_version_ = info.protocolVersion;
  }
  EXPECT_EQ(expected_version_, info.protocolVersion);

  // As with the version; 0 is the null cipher suite (and also invalid).
  if (!expected_cipher_suite_) {
    expected_cipher_suite_ = info.cipherSuite;
  }
  EXPECT_EQ(expected_cipher_suite_, info.cipherSuite);
}
|
|
|
|
// Check that all the expected callbacks have been called.
|
|
void TlsAgent::CheckCallbacks() const {
|
|
// If false start happens, the handshake is reported as being complete at the
|
|
// point that false start happens.
|
|
if (expect_resumption_ || !falsestart_enabled_) {
|
|
EXPECT_TRUE(handshake_callback_called_);
|
|
}
|
|
|
|
// These callbacks shouldn't fire if we are resuming.
|
|
if (role_ == SERVER) {
|
|
EXPECT_EQ(!expect_resumption_, sni_hook_called_);
|
|
} else {
|
|
EXPECT_EQ(!expect_resumption_, auth_certificate_hook_called_);
|
|
// Note that this isn't unconditionally called, even with false start on.
|
|
// But the callback is only skipped if a cipher that is ridiculously weak
|
|
// (80 bits) is chosen. Don't test that: plan to remove bad ciphers.
|
|
EXPECT_EQ(falsestart_enabled_ && !expect_resumption_,
|
|
can_falsestart_hook_called_);
|
|
}
|
|
}
|
|
|
|
void TlsAgent::Connected() {
|
|
LOG("Handshake success");
|
|
CheckCallbacks();
|
|
|
|
SECStatus rv = SSL_GetChannelInfo(ssl_fd_, &info_, sizeof(info_));
|
|
EXPECT_EQ(SECSuccess, rv);
|
|
|
|
// Preliminary values are exposed through callbacks during the handshake.
|
|
// If either expected values were set or the callbacks were called, check
|
|
// that the final values are correct.
|
|
EXPECT_EQ(expected_version_, info_.protocolVersion);
|
|
EXPECT_EQ(expected_cipher_suite_, info_.cipherSuite);
|
|
|
|
rv = SSL_GetCipherSuiteInfo(info_.cipherSuite, &csinfo_, sizeof(csinfo_));
|
|
EXPECT_EQ(SECSuccess, rv);
|
|
|
|
SetState(STATE_CONNECTED);
|
|
}
|
|
|
|
void TlsAgent::EnableExtendedMasterSecret() {
|
|
ASSERT_TRUE(EnsureTlsSetup());
|
|
|
|
SECStatus rv = SSL_OptionSet(ssl_fd_,
|
|
SSL_ENABLE_EXTENDED_MASTER_SECRET,
|
|
PR_TRUE);
|
|
|
|
ASSERT_EQ(SECSuccess, rv);
|
|
}
|
|
|
|
// Checks, from the channel info captured in Connected(), whether the extended
// master secret was used, against |expected|.
void TlsAgent::CheckExtendedMasterSecret(bool expected) {
  const bool used = (info_.extendedMasterSecretUsed != PR_FALSE);
  ASSERT_EQ(expected, used)
      << "unexpected extended master secret state for " << name_;
}
|
|
|
|
void TlsAgent::DisableRollbackDetection() {
|
|
ASSERT_TRUE(EnsureTlsSetup());
|
|
|
|
SECStatus rv = SSL_OptionSet(ssl_fd_,
|
|
SSL_ROLLBACK_DETECTION,
|
|
PR_FALSE);
|
|
|
|
ASSERT_EQ(SECSuccess, rv);
|
|
}
|
|
|
|
void TlsAgent::Handshake() {
|
|
SECStatus rv = SSL_ForceHandshake(ssl_fd_);
|
|
if (rv == SECSuccess) {
|
|
Connected();
|
|
|
|
Poller::Instance()->Wait(READABLE_EVENT, adapter_, this,
|
|
&TlsAgent::ReadableCallback);
|
|
|
|
return;
|
|
}
|
|
|
|
int32_t err = PR_GetError();
|
|
switch (err) {
|
|
case PR_WOULD_BLOCK_ERROR:
|
|
LOG("Would have blocked");
|
|
// TODO(ekr@rtfm.com): set DTLS timeouts
|
|
Poller::Instance()->Wait(READABLE_EVENT, adapter_, this,
|
|
&TlsAgent::ReadableCallback);
|
|
return;
|
|
break;
|
|
|
|
// TODO(ekr@rtfm.com): needs special case for DTLS
|
|
case SSL_ERROR_RX_MALFORMED_HANDSHAKE:
|
|
default:
|
|
if (IS_SSL_ERROR(err)) {
|
|
LOG("Handshake failed with SSL error " << err - SSL_ERROR_BASE);
|
|
} else {
|
|
LOG("Handshake failed with error " << err);
|
|
}
|
|
error_code_ = err;
|
|
SetState(STATE_ERROR);
|
|
return;
|
|
}
|
|
}
|
|
|
|
void TlsAgent::PrepareForRenegotiate() {
|
|
EXPECT_EQ(STATE_CONNECTED, state_);
|
|
|
|
SetState(STATE_CONNECTING);
|
|
}
|
|
|
|
void TlsAgent::StartRenegotiate() {
|
|
PrepareForRenegotiate();
|
|
|
|
SECStatus rv = SSL_ReHandshake(ssl_fd_, PR_TRUE);
|
|
EXPECT_EQ(SECSuccess, rv);
|
|
}
|
|
|
|
void TlsAgent::SendData(size_t bytes, size_t blocksize) {
|
|
uint8_t block[4096];
|
|
|
|
ASSERT_LT(blocksize, sizeof(block));
|
|
|
|
while(bytes) {
|
|
size_t tosend = std::min(blocksize, bytes);
|
|
|
|
for(size_t i = 0; i < tosend; ++i) {
|
|
block[i] = 0xff & send_ctr_;
|
|
++send_ctr_;
|
|
}
|
|
|
|
LOG("Writing " << tosend << " bytes");
|
|
int32_t rv = PR_Write(ssl_fd_, block, tosend);
|
|
ASSERT_EQ(tosend, static_cast<size_t>(rv));
|
|
|
|
bytes -= tosend;
|
|
}
|
|
}
|
|
|
|
void TlsAgent::ReadBytes() {
|
|
uint8_t block[1024];
|
|
|
|
LOG("Reading application data from socket");
|
|
|
|
int32_t rv = PR_Read(ssl_fd_, block, sizeof(block));
|
|
|
|
int32_t err = PR_GetError();
|
|
if (err != PR_WOULD_BLOCK_ERROR) {
|
|
if (expected_read_error_) {
|
|
error_code_ = err;
|
|
} else {
|
|
ASSERT_LE(0, rv);
|
|
size_t count = static_cast<size_t>(rv);
|
|
LOG("Read " << count << " bytes");
|
|
for (size_t i = 0; i < count; ++i) {
|
|
ASSERT_EQ(recv_ctr_ & 0xff, block[i]);
|
|
recv_ctr_++;
|
|
}
|
|
}
|
|
}
|
|
|
|
Poller::Instance()->Wait(READABLE_EVENT, adapter_, this,
|
|
&TlsAgent::ReadableCallback);
|
|
}
|
|
|
|
void TlsAgent::ResetSentBytes() {
|
|
send_ctr_ = 0;
|
|
}
|
|
|
|
// Applies a SessionResumptionMode bit mask. RESUME_SESSIONID enables the
// session cache (by clearing SSL_NO_CACHE), and RESUME_TICKET enables session
// tickets.
void TlsAgent::ConfigureSessionCache(SessionResumptionMode mode) {
  EXPECT_TRUE(EnsureTlsSetup());

  const bool use_session_ids = (mode & RESUME_SESSIONID) != 0;
  const bool use_tickets = (mode & RESUME_TICKET) != 0;

  const SECStatus cache_rv = SSL_OptionSet(
      ssl_fd_, SSL_NO_CACHE, use_session_ids ? PR_FALSE : PR_TRUE);
  EXPECT_EQ(SECSuccess, cache_rv);

  const SECStatus ticket_rv = SSL_OptionSet(
      ssl_fd_, SSL_ENABLE_SESSION_TICKETS, use_tickets ? PR_TRUE : PR_FALSE);
  EXPECT_EQ(SECSuccess, ticket_rv);
}
|
|
|
|
static const std::string kTlsRolesAllArr[] = {"CLIENT", "SERVER"};
|
|
::testing::internal::ParamGenerator<std::string>
|
|
TlsAgentTestBase::kTlsRolesAll = ::testing::ValuesIn(kTlsRolesAllArr);
|
|
|
|
void TlsAgentTestBase::Init() {
|
|
agent_ = new TlsAgent(
|
|
role_ == TlsAgent::CLIENT ? "client" : "server",
|
|
role_, mode_, kea_);
|
|
agent_->Init();
|
|
fd_ = DummyPrSocket::CreateFD("dummy", mode_);
|
|
agent_->adapter()->SetPeer(
|
|
DummyPrSocket::GetAdapter(fd_));
|
|
agent_->StartConnect();
|
|
}
|
|
|
|
void TlsAgentTestBase::EnsureInit() {
|
|
if (!agent_) {
|
|
Init();
|
|
}
|
|
}
|
|
|
|
void TlsAgentTestBase::ProcessMessage(const DataBuffer& buffer,
|
|
TlsAgent::State expected_state,
|
|
int32_t error_code) {
|
|
EnsureInit();
|
|
agent_->adapter()->PacketReceived(buffer);
|
|
agent_->Handshake();
|
|
|
|
ASSERT_EQ(expected_state, agent_->state());
|
|
|
|
if (expected_state == TlsAgent::STATE_ERROR) {
|
|
ASSERT_EQ(error_code, agent_->error_code());
|
|
}
|
|
}
|
|
|
|
} // namespace nss_test
|