tenfourfox/security/nss/external_tests/ssl_gtest/test_io.cc
Cameron Kaiser c9b2922b70 hello FPR
2017-04-19 00:56:45 -07:00

494 lines
12 KiB
C++

/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
/* vim: set ts=2 et sw=2 tw=80: */
/* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this file,
* You can obtain one at http://mozilla.org/MPL/2.0/. */
#include "test_io.h"
#include <algorithm>
#include <cassert>
#include <iostream>
#include <memory>
#include "prerror.h"
#include "prlog.h"
#include "prthread.h"
#include "databuffer.h"
namespace nss_test {
static PRDescIdentity test_fd_identity = PR_INVALID_IO_LAYER;
#define UNIMPLEMENTED() \
std::cerr << "Call to unimplemented function " \
<< __FUNCTION__ << std::endl; \
PR_ASSERT(PR_FALSE); \
PR_SetError(PR_NOT_IMPLEMENTED_ERROR, 0)
#define LOG(a) std::cerr << name_ << ": " << a << std::endl;
class Packet : public DataBuffer {
public:
Packet(const DataBuffer& buf) : DataBuffer(buf), offset_(0) {}
void Advance(size_t delta) {
PR_ASSERT(offset_ + delta <= len());
offset_ = std::min(len(), offset_ + delta);
}
size_t offset() const { return offset_; }
size_t remaining() const { return len() - offset_; }
private:
size_t offset_;
};
// Implementation of NSPR methods
static PRStatus DummyClose(PRFileDesc *f) {
f->secret = nullptr;
return PR_SUCCESS;
}
static int32_t DummyRead(PRFileDesc *f, void *buf, int32_t length) {
DummyPrSocket *io = reinterpret_cast<DummyPrSocket *>(f->secret);
return io->Read(buf, length);
}
static int32_t DummyWrite(PRFileDesc *f, const void *buf, int32_t length) {
DummyPrSocket *io = reinterpret_cast<DummyPrSocket *>(f->secret);
return io->Write(buf, length);
}
static int32_t DummyAvailable(PRFileDesc *f) {
UNIMPLEMENTED();
return -1;
}
int64_t DummyAvailable64(PRFileDesc *f) {
UNIMPLEMENTED();
return -1;
}
static PRStatus DummySync(PRFileDesc *f) {
UNIMPLEMENTED();
return PR_FAILURE;
}
static int32_t DummySeek(PRFileDesc *f, int32_t offset, PRSeekWhence how) {
UNIMPLEMENTED();
return -1;
}
static int64_t DummySeek64(PRFileDesc *f, int64_t offset, PRSeekWhence how) {
UNIMPLEMENTED();
return -1;
}
static PRStatus DummyFileInfo(PRFileDesc *f, PRFileInfo *info) {
UNIMPLEMENTED();
return PR_FAILURE;
}
static PRStatus DummyFileInfo64(PRFileDesc *f, PRFileInfo64 *info) {
UNIMPLEMENTED();
return PR_FAILURE;
}
static int32_t DummyWritev(PRFileDesc *f, const PRIOVec *iov, int32_t iov_size,
PRIntervalTime to) {
UNIMPLEMENTED();
return -1;
}
static PRStatus DummyConnect(PRFileDesc *f, const PRNetAddr *addr,
PRIntervalTime to) {
UNIMPLEMENTED();
return PR_FAILURE;
}
static PRFileDesc *DummyAccept(PRFileDesc *sd, PRNetAddr *addr,
PRIntervalTime to) {
UNIMPLEMENTED();
return nullptr;
}
static PRStatus DummyBind(PRFileDesc *f, const PRNetAddr *addr) {
UNIMPLEMENTED();
return PR_FAILURE;
}
static PRStatus DummyListen(PRFileDesc *f, int32_t depth) {
UNIMPLEMENTED();
return PR_FAILURE;
}
static PRStatus DummyShutdown(PRFileDesc *f, int32_t how) {
DummyPrSocket *io = reinterpret_cast<DummyPrSocket *>(f->secret);
io->Reset();
return PR_SUCCESS;
}
// This function does not support peek.
static int32_t DummyRecv(PRFileDesc *f, void *buf, int32_t buflen,
int32_t flags, PRIntervalTime to) {
PR_ASSERT(flags == 0);
if (flags != 0) {
PR_SetError(PR_NOT_IMPLEMENTED_ERROR, 0);
return -1;
}
DummyPrSocket *io = reinterpret_cast<DummyPrSocket *>(f->secret);
if (io->mode() == DGRAM) {
return io->Recv(buf, buflen);
} else {
return io->Read(buf, buflen);
}
}
// Note: this is always nonblocking and assumes a zero timeout.
static int32_t DummySend(PRFileDesc *f, const void *buf, int32_t amount,
int32_t flags, PRIntervalTime to) {
int32_t written = DummyWrite(f, buf, amount);
return written;
}
static int32_t DummyRecvfrom(PRFileDesc *f, void *buf, int32_t amount,
int32_t flags, PRNetAddr *addr,
PRIntervalTime to) {
UNIMPLEMENTED();
return -1;
}
static int32_t DummySendto(PRFileDesc *f, const void *buf, int32_t amount,
int32_t flags, const PRNetAddr *addr,
PRIntervalTime to) {
UNIMPLEMENTED();
return -1;
}
static int16_t DummyPoll(PRFileDesc *f, int16_t in_flags, int16_t *out_flags) {
UNIMPLEMENTED();
return -1;
}
static int32_t DummyAcceptRead(PRFileDesc *sd, PRFileDesc **nd,
PRNetAddr **raddr, void *buf, int32_t amount,
PRIntervalTime t) {
UNIMPLEMENTED();
return -1;
}
static int32_t DummyTransmitFile(PRFileDesc *sd, PRFileDesc *f,
const void *headers, int32_t hlen,
PRTransmitFileFlags flags, PRIntervalTime t) {
UNIMPLEMENTED();
return -1;
}
static PRStatus DummyGetpeername(PRFileDesc *f, PRNetAddr *addr) {
// TODO: Modify to return unique names for each channel
// somehow, as opposed to always the same static address. The current
// implementation messes up the session cache, which is why it's off
// elsewhere
addr->inet.family = PR_AF_INET;
addr->inet.port = 0;
addr->inet.ip = 0;
return PR_SUCCESS;
}
static PRStatus DummyGetsockname(PRFileDesc *f, PRNetAddr *addr) {
UNIMPLEMENTED();
return PR_FAILURE;
}
static PRStatus DummyGetsockoption(PRFileDesc *f, PRSocketOptionData *opt) {
switch (opt->option) {
case PR_SockOpt_Nonblocking:
opt->value.non_blocking = PR_TRUE;
return PR_SUCCESS;
default:
UNIMPLEMENTED();
break;
}
return PR_FAILURE;
}
// Imitate setting socket options. These are mostly noops.
static PRStatus DummySetsockoption(PRFileDesc *f,
const PRSocketOptionData *opt) {
switch (opt->option) {
case PR_SockOpt_Nonblocking:
return PR_SUCCESS;
case PR_SockOpt_NoDelay:
return PR_SUCCESS;
default:
UNIMPLEMENTED();
break;
}
return PR_FAILURE;
}
static int32_t DummySendfile(PRFileDesc *out, PRSendFileData *in,
PRTransmitFileFlags flags, PRIntervalTime to) {
UNIMPLEMENTED();
return -1;
}
static PRStatus DummyConnectContinue(PRFileDesc *f, int16_t flags) {
UNIMPLEMENTED();
return PR_FAILURE;
}
static int32_t DummyReserved(PRFileDesc *f) {
UNIMPLEMENTED();
return -1;
}
DummyPrSocket::~DummyPrSocket() {
Reset();
}
void DummyPrSocket::Reset() {
delete filter_;
peer_ = nullptr;
while (!input_.empty())
{
Packet* front = input_.front();
input_.pop();
delete front;
}
}
static const struct PRIOMethods DummyMethods = {
PR_DESC_LAYERED, DummyClose, DummyRead,
DummyWrite, DummyAvailable, DummyAvailable64,
DummySync, DummySeek, DummySeek64,
DummyFileInfo, DummyFileInfo64, DummyWritev,
DummyConnect, DummyAccept, DummyBind,
DummyListen, DummyShutdown, DummyRecv,
DummySend, DummyRecvfrom, DummySendto,
DummyPoll, DummyAcceptRead, DummyTransmitFile,
DummyGetsockname, DummyGetpeername, DummyReserved,
DummyReserved, DummyGetsockoption, DummySetsockoption,
DummySendfile, DummyConnectContinue, DummyReserved,
DummyReserved, DummyReserved, DummyReserved};
PRFileDesc *DummyPrSocket::CreateFD(const std::string &name, Mode mode) {
if (test_fd_identity == PR_INVALID_IO_LAYER) {
test_fd_identity = PR_GetUniqueIdentity("testtransportadapter");
}
PRFileDesc *fd = (PR_CreateIOLayerStub(test_fd_identity, &DummyMethods));
fd->secret = reinterpret_cast<PRFilePrivate *>(new DummyPrSocket(name, mode));
return fd;
}
DummyPrSocket *DummyPrSocket::GetAdapter(PRFileDesc *fd) {
return reinterpret_cast<DummyPrSocket *>(fd->secret);
}
void DummyPrSocket::PacketReceived(const DataBuffer& packet) {
input_.push(new Packet(packet));
}
int32_t DummyPrSocket::Read(void *data, int32_t len) {
PR_ASSERT(mode_ == STREAM);
if (mode_ != STREAM) {
PR_SetError(PR_INVALID_METHOD_ERROR, 0);
return -1;
}
if (input_.empty()) {
LOG("Read --> wouldblock " << len);
PR_SetError(PR_WOULD_BLOCK_ERROR, 0);
return -1;
}
Packet *front = input_.front();
size_t to_read = std::min(static_cast<size_t>(len),
front->len() - front->offset());
memcpy(data, static_cast<const void*>(front->data() + front->offset()),
to_read);
front->Advance(to_read);
if (!front->remaining()) {
input_.pop();
delete front;
}
return static_cast<int32_t>(to_read);
}
int32_t DummyPrSocket::Recv(void *buf, int32_t buflen) {
if (input_.empty()) {
PR_SetError(PR_WOULD_BLOCK_ERROR, 0);
return -1;
}
Packet *front = input_.front();
if (static_cast<size_t>(buflen) < front->len()) {
PR_ASSERT(false);
PR_SetError(PR_BUFFER_OVERFLOW_ERROR, 0);
return -1;
}
size_t count = front->len();
memcpy(buf, front->data(), count);
input_.pop();
delete front;
return static_cast<int32_t>(count);
}
int32_t DummyPrSocket::Write(const void *buf, int32_t length) {
if (!peer_) {
PR_SetError(PR_IO_ERROR, 0);
return -1;
}
DataBuffer packet(static_cast<const uint8_t*>(buf),
static_cast<size_t>(length));
DataBuffer filtered;
if (filter_ && filter_->Filter(packet, &filtered)) {
LOG("Filtered packet: " << filtered);
peer_->PacketReceived(filtered);
} else {
peer_->PacketReceived(packet);
}
// libssl can't handle it if this reports something other than the length
// of what was passed in (or less, but we're not doing partial writes).
return static_cast<int32_t>(packet.len());
}
Poller *Poller::instance;
Poller *Poller::Instance() {
if (!instance) instance = new Poller();
return instance;
}
void Poller::Shutdown() {
delete instance;
instance = nullptr;
}
void Poller::Wait(Event event, DummyPrSocket *adapter, PollTarget *target,
PollCallback cb) {
auto it = waiters_.find(adapter);
Waiter *waiter;
if (it == waiters_.end()) {
waiter = new Waiter(adapter);
} else {
waiter = it->second;
}
assert(event < TIMER_EVENT);
if (event >= TIMER_EVENT) return;
waiter->targets_[event] = target;
waiter->callbacks_[event] = cb;
waiters_[adapter] = waiter;
}
void Poller::Cancel(Event event, DummyPrSocket *adapter) {
auto it = waiters_.find(adapter);
Waiter *waiter;
if (it == waiters_.end()) {
return;
}
waiter = it->second;
waiter->targets_[event] = nullptr;
waiter->callbacks_[event] = nullptr;
// Clean up if there are no callbacks.
for (size_t i=0; i<TIMER_EVENT; ++i) {
if (waiter->callbacks_[i])
return;
}
delete waiter;
waiters_.erase(adapter);
}
void Poller::SetTimer(uint32_t timer_ms, PollTarget *target, PollCallback cb,
Timer **timer) {
Timer *t = new Timer(PR_Now() + timer_ms * 1000, target, cb);
timers_.push(t);
if (timer) *timer = t;
}
bool Poller::Poll() {
std::cerr << "Poll()\n";
PRIntervalTime timeout = PR_INTERVAL_NO_TIMEOUT;
PRTime now = PR_Now();
bool fired = false;
// Figure out the timer for the select.
if (!timers_.empty()) {
Timer *first_timer = timers_.top();
if (now >= first_timer->deadline_) {
// Timer expired.
timeout = PR_INTERVAL_NO_WAIT;
} else {
timeout =
PR_MillisecondsToInterval((first_timer->deadline_ - now) / 1000);
}
}
for (auto it = waiters_.begin(); it != waiters_.end(); ++it) {
Waiter *waiter = it->second;
if (waiter->callbacks_[READABLE_EVENT]) {
if (waiter->io_->readable()) {
PollCallback callback = waiter->callbacks_[READABLE_EVENT];
PollTarget *target = waiter->targets_[READABLE_EVENT];
waiter->callbacks_[READABLE_EVENT] = nullptr;
waiter->targets_[READABLE_EVENT] = nullptr;
callback(target, READABLE_EVENT);
fired = true;
}
}
}
if (fired) timeout = PR_INTERVAL_NO_WAIT;
// Can't wait forever and also have nothing readable now.
if (timeout == PR_INTERVAL_NO_TIMEOUT) return false;
// Sleep.
if (timeout != PR_INTERVAL_NO_WAIT) {
PR_Sleep(timeout);
}
// Now process anything that timed out.
now = PR_Now();
while (!timers_.empty()) {
if (now < timers_.top()->deadline_) break;
Timer *timer = timers_.top();
timers_.pop();
if (timer->callback_) {
timer->callback_(timer->target_, TIMER_EVENT);
}
delete timer;
}
return true;
}
} // namespace nss_test