Improved remote connection error handling (#376)

* Improved remote connection error handling

* Improved error handling by adding a magic string to each message

* Interface comment update

* Interface comment update

* Improve error messages

* Clients send the magic word to authenticate

Co-authored-by: Daniel Markstedt <markstedt@gmail.com>
This commit is contained in:
Uwe Seimet 2021-10-26 01:04:10 +02:00 committed by GitHub
parent f0a7deb361
commit cc1783c1cd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 35 additions and 13 deletions

View File

@ -235,6 +235,8 @@ def send_over_socket(s, payload):
Reads data from socket in 2048 bytes chunks until all data is received. Reads data from socket in 2048 bytes chunks until all data is received.
""" """
# Sending the magic word "RASCSI" to authenticate with the server
s.send(b"RASCSI")
# Prepending a little endian 32bit header with the message size # Prepending a little endian 32bit header with the message size
s.send(pack("<i", len(payload))) s.send(pack("<i", len(payload)))
s.send(payload) s.send(payload)
@ -250,16 +252,17 @@ def send_over_socket(s, payload):
while bytes_recvd < response_length: while bytes_recvd < response_length:
chunk = s.recv(min(response_length - bytes_recvd, 2048)) chunk = s.recv(min(response_length - bytes_recvd, 2048))
if chunk == b'': if chunk == b'':
exit("Read an empty chunk from the socket. \ exit("Socket connection has dropped unexpectedly. "
Socket connection has dropped unexpectedly. \ "RaSCSI may have crashed."
RaSCSI may have has crashed.") )
chunks.append(chunk) chunks.append(chunk)
bytes_recvd = bytes_recvd + len(chunk) bytes_recvd = bytes_recvd + len(chunk)
response_message = b''.join(chunks) response_message = b''.join(chunks)
return response_message return response_message
else: else:
exit("The response from RaSCSI did not contain a protobuf header. \ exit("The response from RaSCSI did not contain a protobuf header. "
RaSCSI may have crashed.") "RaSCSI may have crashed."
)
def formatted_output(): def formatted_output():

View File

@ -70,12 +70,12 @@ void protobuf_util::SerializeMessage(int fd, const google::protobuf::Message& me
// Write the size of the protobuf data as a header // Write the size of the protobuf data as a header
int32_t size = data.length(); int32_t size = data.length();
if (write(fd, &size, sizeof(size)) != sizeof(size)) { if (write(fd, &size, sizeof(size)) != sizeof(size)) {
throw io_exception("Can't write protobuf header"); throw io_exception("Can't write protobuf message header");
} }
// Write the actual protobuf data // Write the actual protobuf data
if (write(fd, data.data(), size) != size) { if (write(fd, data.data(), size) != size) {
throw io_exception("Can't write protobuf data"); throw io_exception("Can't write protobuf message data");
} }
} }
@ -83,17 +83,20 @@ void protobuf_util::DeserializeMessage(int fd, google::protobuf::Message& messag
{ {
// Read the header with the size of the protobuf data // Read the header with the size of the protobuf data
uint8_t header_buf[4]; uint8_t header_buf[4];
int bytes_read = ReadNBytes(fd, header_buf, 4); int bytes_read = ReadNBytes(fd, header_buf, sizeof(header_buf));
if (bytes_read < 4) { if (bytes_read < (int)sizeof(header_buf)) {
return; return;
} }
int32_t size = (header_buf[3] << 24) + (header_buf[2] << 16) + (header_buf[1] << 8) + header_buf[0]; int32_t size = (header_buf[3] << 24) + (header_buf[2] << 16) + (header_buf[1] << 8) + header_buf[0];
if (size <= 0) {
throw io_exception("Broken protobuf message header");
}
// Read the binary protobuf data // Read the binary protobuf data
uint8_t data_buf[size]; uint8_t data_buf[size];
bytes_read = ReadNBytes(fd, data_buf, size); bytes_read = ReadNBytes(fd, data_buf, size);
if (bytes_read < size) { if (bytes_read < size) {
throw io_exception("Missing protobuf data"); throw io_exception("Missing protobuf message data");
} }
// Create protobuf message // Create protobuf message

View File

@ -1291,6 +1291,16 @@ static void *MonThread(void *param)
throw io_exception("accept() failed"); throw io_exception("accept() failed");
} }
// Read magic string
char magic[6];
int bytes_read = ReadNBytes(fd, (uint8_t *)magic, sizeof(magic));
if (!bytes_read) {
continue;
}
if (bytes_read != sizeof(magic) || strncmp(magic, "RASCSI", sizeof(magic))) {
throw io_exception("Invalid magic");
}
// Fetch the command // Fetch the command
PbCommand command; PbCommand command;
DeserializeMessage(fd, command); DeserializeMessage(fd, command);

View File

@ -1,6 +1,6 @@
// //
// Each rascsi remote interface message is preceded by a little endian 32 bit header, // Each rascsi message sent to the rascsi server is preceded by the magic string "RASCSI".
// which contains the protobuf message size. // A message starts with a little endian 32 bit header which contains the protobuf message size.
// Unless explicitly specified the order of repeated data returned is undefined. // Unless explicitly specified the order of repeated data returned is undefined.
// //

View File

@ -60,7 +60,11 @@ void RasctlCommands::SendCommand()
throw io_exception(error.str()); throw io_exception(error.str());
} }
SerializeMessage(fd, command); if (write(fd, "RASCSI", 6) != 6) {
throw io_exception("Can't write magic");
}
SerializeMessage(fd, command);
} }
catch(const io_exception& e) { catch(const io_exception& e) {
cerr << "Error: " << e.getmsg() << endl; cerr << "Error: " << e.getmsg() << endl;

View File

@ -409,6 +409,8 @@ def send_over_socket(s, payload):
""" """
from struct import pack, unpack from struct import pack, unpack
# Sending the magic word "RASCSI" to authenticate with the server
s.send(b"RASCSI")
# Prepending a little endian 32bit header with the message size # Prepending a little endian 32bit header with the message size
s.send(pack("<i", len(payload))) s.send(pack("<i", len(payload)))
s.send(payload) s.send(payload)