765 lines
26 KiB
C++
765 lines
26 KiB
C++
/*
MIT License

Copyright (c) 2020 Meng Rao <raomeng1@gmail.com>

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
*/
|
|
#pragma once
|
|
#include <arpa/inet.h>
#include <endian.h>
#include <errno.h>
#include <fcntl.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <time.h>
#include <unistd.h>

#include <algorithm>
#include <limits>
#include <memory>
#include <string>
#include <utility>
|
|
|
|
namespace websocket {
|
|
|
|
// Non-blocking IPv4 TCP connection with a fixed receive buffer of RecvBufSize bytes.
//
// The object owns the socket descriptor: close() (also run by the destructor) releases it and
// records a human-readable reason that getLastError() returns. Received bytes accumulate in
// recvbuf_[head_, tail_) and are handed to a caller-supplied handler from read().
template<uint32_t RecvBufSize>
class SocketTcpConnection
{
public:
  ~SocketTcpConnection() { close("destruct"); }

  // Text recorded by the most recent saveError() call (empty string if none yet).
  const char* getLastError() { return last_error_; }

  bool isConnected() { return fd_ >= 0; }

  // Connects to server_ip:server_port (server_ip must be a dotted-quad IPv4 literal).
  // The ::connect() call is made before the socket is switched to non-blocking mode in open().
  // Returns false and records an error on failure; an already established connection is only
  // replaced (and its descriptor released) once the new one has been connected.
  bool connect(const char* server_ip, uint16_t server_port) {
    struct sockaddr_in server_addr;
    memset(&server_addr, 0, sizeof(server_addr));
    server_addr.sin_family = AF_INET;
    server_addr.sin_port = htons(server_port);
    // inet_pton returns 1 only on success; anything else leaves sin_addr unusable.
    if (!server_ip || inet_pton(AF_INET, server_ip, &(server_addr.sin_addr)) != 1) {
      saveError("invalid server ip", false);
      return false;
    }
    int fd = socket(AF_INET, SOCK_STREAM, 0);
    if (fd < 0) {
      saveError("socket error", true);
      return false;
    }
    if (::connect(fd, (struct sockaddr*)&server_addr, sizeof(server_addr)) < 0) {
      saveError("connect error", true); // record errno before ::close() can change it
      ::close(fd);
      return false;
    }
    if (fd_ >= 0) ::close(fd_); // do not leak a previously connected descriptor
    return open(fd);
  }

  // Fills addr with the remote address; returns false if getpeername() fails.
  bool getPeername(struct sockaddr_in& addr) {
    socklen_t addr_len = sizeof(addr);
    return ::getpeername(fd_, (struct sockaddr*)&addr, &addr_len) == 0;
  }

  // Closes the socket if it is open and records `reason` (plus strerror(errno) when
  // check_errno is true). Does nothing, and keeps the previous error text, if already closed.
  void close(const char* reason, bool check_errno = false) {
    if (fd_ >= 0) {
      saveError(reason, check_errno);
      ::close(fd_);
      fd_ = -1;
    }
  }

  // Sends all `size` bytes, spinning while the socket reports EAGAIN/EINTR.
  // `more` adds MSG_MORE to the send flags. On any other error the connection is closed
  // and false is returned.
  bool write(const uint8_t* data, uint32_t size, bool more = false) {
    int flags = MSG_NOSIGNAL;
    if (more) flags |= MSG_MORE;
    // do/while on purpose: ::send() is issued once even when size == 0, as the original did.
    do {
      ssize_t sent = ::send(fd_, data, size, flags);
      if (sent < 0) {
        if (errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) {
          close("send error", true);
          return false;
        }
        continue;
      }
      data += sent;
      size -= (uint32_t)sent;
    } while (size != 0);
    return true;
  }

  // Performs one ::read() into the receive buffer and, if bytes arrived, calls
  // handler(const char* data, uint32_t size) with all buffered bytes. The handler returns the
  // number of trailing bytes it did NOT consume; those are kept for the next call.
  // Returns true if the handler was invoked, false if nothing was read or the connection closed.
  template<typename Handler>
  bool read(Handler handler) {
    ssize_t ret = ::read(fd_, recvbuf_ + tail_, RecvBufSize - tail_);
    if (ret < 0) {
      // EAGAIN/EINTR just mean "nothing to read right now" on a non-blocking socket.
      if (errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) close("read error", true);
      return false;
    }
    if (ret == 0) {
      close("remote close");
      return false;
    }
    tail_ += (uint32_t)ret;

    uint32_t remaining = handler(recvbuf_ + head_, tail_ - head_);
    if (remaining == 0) {
      head_ = tail_ = 0;
    }
    else {
      head_ = tail_ - remaining;
      if (head_ >= RecvBufSize / 2) {
        // Compact the unconsumed tail to the front. memmove because source and destination
        // can overlap by one byte when RecvBufSize is odd.
        memmove(recvbuf_, recvbuf_ + head_, remaining);
        head_ = 0;
        tail_ = remaining;
      }
      else if (tail_ == RecvBufSize) {
        close("recv buf full");
      }
    }
    return true;
  }

protected:
  template<uint32_t>
  friend class SocketTcpServer;

  // Adopts an already connected descriptor: resets the buffer, switches the socket to
  // non-blocking mode and enables TCP_NODELAY. On failure the descriptor is closed.
  bool open(int fd) {
    fd_ = fd;
    head_ = tail_ = 0;

    int flags = fcntl(fd_, F_GETFL, 0);
    if (flags < 0 || fcntl(fd_, F_SETFL, flags | O_NONBLOCK) < 0) {
      close("fcntl O_NONBLOCK error", true);
      return false;
    }

    int yes = 1;
    if (setsockopt(fd_, IPPROTO_TCP, TCP_NODELAY, &yes, sizeof(yes)) < 0) {
      close("setsockopt TCP_NODELAY error", true);
      return false;
    }

    return true;
  }

  // Stores "<msg> <strerror(errno)>" (or "<msg> " when check_errno is false), truncated to fit.
  void saveError(const char* msg, bool check_errno) {
    snprintf(last_error_, sizeof(last_error_), "%s %s", msg, check_errno ? (const char*)strerror(errno) : "");
  }

  int fd_ = -1;
  uint32_t head_ = 0; // offset of the first unconsumed byte in recvbuf_
  uint32_t tail_ = 0; // offset one past the last received byte in recvbuf_
  char recvbuf_[RecvBufSize];
  char last_error_[64] = "";
};
|
|
|
|
template<uint32_t RecvBufSize = 4096>
|
|
class SocketTcpServer
|
|
{
|
|
public:
|
|
using TcpConnection = SocketTcpConnection<RecvBufSize>;
|
|
|
|
bool init(const char* interface, const char* server_ip, uint16_t server_port) {
|
|
listenfd_ = socket(AF_INET, SOCK_STREAM, 0);
|
|
if (listenfd_ < 0) {
|
|
saveError("socket error");
|
|
return false;
|
|
}
|
|
|
|
int flags = fcntl(listenfd_, F_GETFL, 0);
|
|
if (fcntl(listenfd_, F_SETFL, flags | O_NONBLOCK) < 0) {
|
|
close("fcntl O_NONBLOCK error");
|
|
return false;
|
|
}
|
|
|
|
int yes = 1;
|
|
if (setsockopt(listenfd_, SOL_SOCKET, SO_REUSEADDR, &yes, sizeof(yes)) < 0) {
|
|
close("setsockopt SO_REUSEADDR error");
|
|
return false;
|
|
}
|
|
|
|
struct sockaddr_in local_addr;
|
|
local_addr.sin_family = AF_INET;
|
|
inet_pton(AF_INET, server_ip, &(local_addr.sin_addr));
|
|
local_addr.sin_port = htons(server_port);
|
|
bzero(&(local_addr.sin_zero), 8);
|
|
if (bind(listenfd_, (struct sockaddr*)&local_addr, sizeof(local_addr)) < 0) {
|
|
close("bind error");
|
|
return false;
|
|
}
|
|
if (listen(listenfd_, 5) < 0) {
|
|
close("listen error");
|
|
return false;
|
|
}
|
|
|
|
return true;
|
|
};
|
|
|
|
void close(const char* reason) {
|
|
if (listenfd_ >= 0) {
|
|
saveError(reason);
|
|
::close(listenfd_);
|
|
listenfd_ = -1;
|
|
}
|
|
}
|
|
|
|
const char* getLastError() { return last_error_; };
|
|
|
|
~SocketTcpServer() { close("destruct"); }
|
|
|
|
bool accept2(TcpConnection& conn) {
|
|
struct sockaddr_in clientaddr;
|
|
socklen_t addr_len = sizeof(clientaddr);
|
|
int fd = ::accept(listenfd_, (struct sockaddr*)&(clientaddr), &addr_len);
|
|
if (fd < 0) {
|
|
return false;
|
|
}
|
|
if (!conn.open(fd)) {
|
|
return false;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
private:
|
|
void saveError(const char* msg) { snprintf(last_error_, sizeof(last_error_), "%s %s", msg, strerror(errno)); }
|
|
|
|
int listenfd_ = -1;
|
|
char last_error_[64] = "";
|
|
};
|
|
|
|
// Returns the current CLOCK_REALTIME reading in nanoseconds.
inline uint64_t getns() {
  timespec ts;
  ::clock_gettime(CLOCK_REALTIME, &ts);
  // Widen before multiplying: tv_sec * 1000000000 would otherwise be evaluated in time_t's
  // width and overflow on platforms where that is 32 bits.
  return static_cast<uint64_t>(ts.tv_sec) * 1000000000ULL + static_cast<uint64_t>(ts.tv_nsec);
}
|
|
|
|
// WebSocket frame opcodes (low 4 bits of the first header byte).
// Values with bit 3 set (>= 0x8) are treated as control frames by WSConnection::send().
static const uint8_t OPCODE_CONT = 0x0;   // continuation of a fragmented message
static const uint8_t OPCODE_TEXT = 0x1;   // text data frame
static const uint8_t OPCODE_BINARY = 0x2; // binary data frame
static const uint8_t OPCODE_CLOSE = 0x8;  // close control frame
static const uint8_t OPCODE_PING = 0x9;   // ping control frame
static const uint8_t OPCODE_PONG = 0xA;   // pong control frame
|
|
|
|
template<typename EventHandler, typename ConnUserData, bool RecvSegment, uint32_t RecvBufSize, bool SendMask>
|
|
class WSConnection
|
|
{
|
|
public:
|
|
ConnUserData user_data;
|
|
|
|
// get remote network address
|
|
bool getPeername(struct sockaddr_in& addr) { return conn.getPeername(addr); }
|
|
|
|
bool isConnected() { return conn.isConnected(); }
|
|
|
|
// if sending a msg of multiple segments, only set fin to true for the last one
|
|
void send(uint8_t opcode, const uint8_t* payload, uint32_t pl_len, bool fin = true) {
|
|
uint8_t h[14];
|
|
uint32_t h_len = 2;
|
|
if (opcode >> 3) // if control
|
|
fin = true;
|
|
else {
|
|
if (!send_fin) opcode = OPCODE_CONT;
|
|
send_fin = fin;
|
|
}
|
|
h[0] = (opcode & 15) | ((uint8_t)fin << 7);
|
|
h[1] = (uint8_t)SendMask << 7;
|
|
if (pl_len < 126) {
|
|
h[1] |= (uint8_t)pl_len;
|
|
}
|
|
else if (pl_len < 65536) {
|
|
h[1] |= 126;
|
|
*(uint16_t*)(h + 2) = htobe16(pl_len);
|
|
h_len += 2;
|
|
}
|
|
else {
|
|
h[1] |= 127;
|
|
*(uint64_t*)(h + 2) = htobe64(pl_len);
|
|
h_len += 8;
|
|
}
|
|
if (SendMask) { // for efficency and simplicity masking-key is always set to 0
|
|
*(uint32_t*)(h + h_len) = 0;
|
|
h_len += 4;
|
|
}
|
|
conn.write(h, h_len, true);
|
|
conn.write(payload, pl_len, false);
|
|
}
|
|
|
|
// clean close the connection with optional status_code and reason
|
|
void close(uint16_t status_code = 1005, const char* reason = "") {
|
|
*(uint16_t*)close_reason = htobe16(status_code);
|
|
uint32_t reason_len = snprintf((char*)close_reason + 2, sizeof(close_reason) - 2, "%s", reason);
|
|
if (status_code != 1005) {
|
|
send(OPCODE_CLOSE, close_reason, 2 + reason_len);
|
|
}
|
|
else
|
|
send(OPCODE_CLOSE, nullptr, 0);
|
|
conn.close("clean close");
|
|
}
|
|
|
|
protected:
|
|
template<typename, typename, bool, uint32_t, uint32_t>
|
|
friend class WSServer;
|
|
|
|
void init(uint64_t expire) {
|
|
open = false;
|
|
send_fin = true;
|
|
*(uint16_t*)close_reason = htobe16(1006);
|
|
close_reason[2] = 0;
|
|
frame_size = 0;
|
|
expire_time = expire;
|
|
}
|
|
|
|
uint32_t handleWSMsg(EventHandler* handler, uint8_t* data, uint32_t size) {
|
|
// we might read a little more bytes beyond size, which is okey
|
|
const uint8_t* data_end = data + size;
|
|
uint8_t opcode = data[0] & 15;
|
|
bool beg = opcode != OPCODE_CONT, fin = data[0] >> 7; //, control = opcode >> 3;
|
|
bool mask = data[1] >> 7;
|
|
uint8_t mask_key[4];
|
|
uint64_t pl_len = data[1] & 127;
|
|
data += 2;
|
|
if (pl_len == 126) {
|
|
pl_len = be16toh(*(uint16_t*)data);
|
|
data += 2;
|
|
}
|
|
else if (pl_len == 127) {
|
|
pl_len = be64toh(*(uint64_t*)data) & ~(1ULL << 63);
|
|
data += 8;
|
|
}
|
|
if (mask) {
|
|
*(uint32_t*)mask_key = *(uint32_t*)data;
|
|
data += 4;
|
|
}
|
|
if (data_end - data < (int64_t)pl_len) {
|
|
if (size + (data + pl_len - data_end) > RecvBufSize) close(1009);
|
|
return size;
|
|
}
|
|
if (mask) {
|
|
for (uint64_t i = 0; i < pl_len; i++) data[i] ^= mask_key[i & 3];
|
|
}
|
|
if (RecvSegment || (beg && fin)) {
|
|
if (opcode == OPCODE_CLOSE) {
|
|
uint16_t status_code = 1005;
|
|
char reason[128] = {0};
|
|
if (pl_len >= 2) {
|
|
status_code = be16toh(*(uint16_t*)data);
|
|
uint64_t reason_len = std::min(sizeof(reason) - 1, pl_len - 2);
|
|
memcpy(reason, data + 2, reason_len);
|
|
reason[reason_len] = 0;
|
|
}
|
|
close(status_code, reason);
|
|
}
|
|
else {
|
|
#if __cplusplus >= 201703L
|
|
if constexpr (RecvSegment) {
|
|
#else
|
|
if (RecvSegment) {
|
|
#endif
|
|
if (beg) recv_opcode = opcode;
|
|
handler->onWSSegment(*this, recv_opcode, data, pl_len, frame_size, fin);
|
|
if (fin)
|
|
frame_size = 0;
|
|
else
|
|
frame_size += pl_len;
|
|
}
|
|
else
|
|
handler->onWSMsg(*this, opcode, data, pl_len);
|
|
}
|
|
}
|
|
#if __cplusplus >= 201703L
|
|
else if constexpr (!RecvSegment) {
|
|
#else
|
|
else {
|
|
#endif
|
|
if (frame_size + pl_len > RecvBufSize)
|
|
close(1009);
|
|
else {
|
|
memcpy(frame + frame_size, data, pl_len);
|
|
frame_size += pl_len;
|
|
if (beg) recv_opcode = opcode;
|
|
if (fin) {
|
|
handler->onWSMsg(*this, recv_opcode, frame, frame_size);
|
|
frame_size = 0;
|
|
}
|
|
}
|
|
}
|
|
return data_end - (data + pl_len);
|
|
}
|
|
|
|
void handleWSClose(EventHandler* handler) {
|
|
uint16_t status_code = be16toh(*(uint16_t*)close_reason);
|
|
const char* reason = (const char*)close_reason + 2;
|
|
if (status_code == 1006) reason = conn.getLastError();
|
|
handler->onWSClose(*this, status_code, reason);
|
|
}
|
|
|
|
bool open;
|
|
bool send_fin;
|
|
uint8_t recv_opcode;
|
|
uint32_t frame_size;
|
|
uint64_t expire_time;
|
|
uint8_t frame[RecvSegment ? 0 : RecvBufSize];
|
|
typename SocketTcpServer<RecvBufSize>::TcpConnection conn;
|
|
uint8_t close_reason[128]; // first 2 bytes are status_code(big endian)
|
|
};
|
|
|
|
template<typename EventHandler, typename ConnUserData = char, bool RecvSegment = false, uint32_t RecvBufSize = 4096,
|
|
typename ConnectionType = WSConnection<EventHandler, ConnUserData, RecvSegment, RecvBufSize, true>>
|
|
class WSClient : public ConnectionType
|
|
{
|
|
public:
|
|
using Connection = ConnectionType;
|
|
// using Connection = WSConnection<EventHandler, ConnUserData, RecvSegment, RecvBufSize, true>;
|
|
|
|
const char* getLastError() { return this->conn.getLastError(); }
|
|
|
|
// timeout: connect timeout in milliseconds, 0 means no limit
|
|
// if failed, call getLastError() for the reason
|
|
bool wsConnect(uint64_t timeout, const char* server_ip, uint16_t server_port, const char* request_uri,
|
|
const char* host, const char* origin = nullptr, const char* protocol = nullptr,
|
|
const char* extensions = nullptr, char* resp_protocol = nullptr, uint32_t resp_protocol_size = 0,
|
|
char* resp_extensions = nullptr, uint32_t resp_extensions_size = 0) {
|
|
uint64_t now = getns();
|
|
uint64_t expire = timeout > 0 ? now + timeout * 1000000 : std::numeric_limits<uint64_t>::max();
|
|
if (!this->conn.connect(server_ip, server_port)) return false;
|
|
if (getns() > expire) {
|
|
this->conn.close("timeout");
|
|
return false;
|
|
}
|
|
this->init(expire);
|
|
char req[2048];
|
|
uint32_t req_len =
|
|
snprintf(req, sizeof(req),
|
|
"GET %s HTTP/1.1\r\nHost: %s\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: "
|
|
"dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n",
|
|
request_uri, host);
|
|
if (origin) req_len += snprintf(req + req_len, sizeof(req) - req_len, "Origin: %s\r\n", origin);
|
|
if (protocol) req_len += snprintf(req + req_len, sizeof(req) - req_len, "Sec-WebSocket-Protocol: %s\r\n", protocol);
|
|
if (extensions)
|
|
req_len += snprintf(req + req_len, sizeof(req) - req_len, "Sec-WebSocket-Extensions: %s\r\n", extensions);
|
|
req_len += snprintf(req + req_len, sizeof(req) - req_len, "\r\n");
|
|
if (req_len >= sizeof(req) - 1) {
|
|
this->conn.close("request msg too long");
|
|
return false;
|
|
}
|
|
this->conn.write((uint8_t*)req, req_len);
|
|
while (!this->open && this->isConnected()) {
|
|
this->conn.read([&](const char* data, uint32_t size) -> uint32_t {
|
|
const char* data_end = data + size;
|
|
bool status_code_checked = false, upgrade_checked = false, connection_checked = false, accept_checked = false;
|
|
while (true) {
|
|
const char* ln = (char*)memchr(data, '\n', data_end - data);
|
|
if (!ln) return size;
|
|
if (*--ln != '\r') break;
|
|
if (!status_code_checked) { // first line
|
|
if (memcmp(data, "HTTP/", 5)) break;
|
|
const char* status_code = (char*)memchr(data, ' ', ln - data);
|
|
if (!status_code) break;
|
|
while (*status_code == ' ') status_code++;
|
|
if (memcmp(status_code, "101 ", 4)) break;
|
|
status_code_checked = true;
|
|
}
|
|
else {
|
|
const char* val_end = ln;
|
|
while (val_end[-1] == ' ') val_end--;
|
|
if (val_end == data) { // end of headers
|
|
if (!upgrade_checked || !connection_checked || !accept_checked) break;
|
|
this->open = true;
|
|
return data_end - ln - 2;
|
|
}
|
|
const char* colon = (char*)memchr(data, ':', ln - data);
|
|
if (!colon) break;
|
|
const char* val = colon + 1;
|
|
while (*val == ' ') val++;
|
|
uint32_t key_len = colon - data;
|
|
uint32_t val_len = val_end - val;
|
|
if (key_len == 7 && !memcmp(data, "Upgrade", 7)) {
|
|
if (memcmp(val, "websocket", 9)) break;
|
|
upgrade_checked = true;
|
|
}
|
|
else if (key_len == 10 && !memcmp(data, "Connection", 10)) {
|
|
if (!memcmp(val, "Upgrade", 7)) connection_checked = true;
|
|
}
|
|
else if (key_len == 20 && !memcmp(data, "Sec-WebSocket-Accept", 20)) {
|
|
if (val_len != 28 || memcmp(val, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", 28)) break;
|
|
accept_checked = true;
|
|
}
|
|
else if (key_len == 22 && !memcmp(data, "Sec-WebSocket-Protocol", 22) && resp_protocol_size > 0) {
|
|
uint32_t cp_len = std::min(resp_protocol_size - 1, val_len);
|
|
memcpy(resp_protocol, val, cp_len);
|
|
resp_protocol[cp_len] = 0;
|
|
}
|
|
else if (key_len == 24 && !memcmp(data, "Sec-WebSocket-Extensions", 24) && resp_extensions_size > 0) {
|
|
uint32_t cp_len = std::min(resp_extensions_size - 1, val_len);
|
|
memcpy(resp_extensions, val, cp_len);
|
|
resp_extensions[cp_len] = 0;
|
|
}
|
|
}
|
|
data = ln + 2; // skip \r\n
|
|
}
|
|
this->conn.close("request failed");
|
|
return size;
|
|
});
|
|
if (getns() > expire) this->conn.close("timeout");
|
|
}
|
|
return this->isConnected();
|
|
}
|
|
|
|
void poll(EventHandler* handler) {
|
|
this->conn.read([&](const char* data, uint32_t size) { return this->handleWSMsg(handler, (uint8_t*)data, size); });
|
|
if (!this->isConnected()) this->handleWSClose(handler);
|
|
}
|
|
};
|
|
|
|
template<typename EventHandler, typename ConnUserData = char, bool RecvSegment = false, uint32_t RecvBufSize = 4096,
|
|
uint32_t MaxConns = 10>
|
|
class WSServer
|
|
{
|
|
public:
|
|
using TcpServer = SocketTcpServer<RecvBufSize>;
|
|
using Connection = WSConnection<EventHandler, ConnUserData, RecvSegment, RecvBufSize, false>;
|
|
|
|
WSServer() {
|
|
for (int i = 0; i < MaxConns; i++) {
|
|
conns_[i] = conns_data_ + i;
|
|
}
|
|
}
|
|
|
|
const char* getLastError() { return server_.getLastError(); }
|
|
|
|
// newconn_timeout: new tcp connection max inactive time in milliseconds, 0 means no limit
|
|
// openconn_timeout: open ws connection max inactive time in milliseconds, 0 means no limit
|
|
// if failed, call getLastError() for the reason
|
|
bool init(const char* server_ip, uint16_t server_port, uint64_t newconn_timeout = 0, uint64_t openconn_timeout = 0) {
|
|
newconn_timeout_ = newconn_timeout * 1000000;
|
|
openconn_timeout_ = openconn_timeout * 1000000;
|
|
return server_.init("", server_ip, server_port);
|
|
}
|
|
|
|
void poll(EventHandler* handler) {
|
|
uint64_t now = getns();
|
|
uint64_t new_expire = newconn_timeout_ ? now + newconn_timeout_ : std::numeric_limits<uint64_t>::max();
|
|
uint64_t open_expire = openconn_timeout_ ? now + openconn_timeout_ : std::numeric_limits<uint64_t>::max();
|
|
if (conns_cnt_ < MaxConns) {
|
|
Connection& new_conn = *conns_[conns_cnt_];
|
|
if (server_.accept2(new_conn.conn)) {
|
|
new_conn.init(new_expire);
|
|
conns_cnt_++;
|
|
}
|
|
}
|
|
for (int i = 0; i < conns_cnt_;) {
|
|
Connection& conn = *conns_[i];
|
|
conn.conn.read([&](const char* data, uint32_t size) {
|
|
uint32_t remaining =
|
|
conn.open ? conn.handleWSMsg(handler, (uint8_t*)data, size) : handleHttpRequest(handler, conn, data, size);
|
|
if (remaining < size) conn.expire_time = conn.open ? open_expire : new_expire;
|
|
return remaining;
|
|
});
|
|
if (now > conn.expire_time) conn.conn.close("timeout");
|
|
if (conn.isConnected())
|
|
i++;
|
|
else {
|
|
if (conn.open) conn.handleWSClose(handler);
|
|
std::swap(conns_[i], conns_[--conns_cnt_]);
|
|
}
|
|
}
|
|
}
|
|
|
|
void sendMsg(const std::string &msg)
|
|
{
|
|
for (int i = 0; i < conns_cnt_; i++)
|
|
{
|
|
Connection &conn = *conns_[i];
|
|
if (conn.isConnected())
|
|
{
|
|
conn.send(websocket::OPCODE_TEXT, (const uint8_t *)msg.data(), msg.size());
|
|
}
|
|
else
|
|
{
|
|
conn.close();
|
|
}
|
|
}
|
|
}
|
|
|
|
private:
|
|
static uint32_t rol(uint32_t value, uint32_t bits) { return (value << bits) | (value >> (32 - bits)); }
|
|
// Be cautious that *in* will be modified and up to 64 bytes will be appended, so make sure in buffer is long enough
|
|
static uint32_t sha1base64(uint8_t* in, uint64_t in_len, char* out) {
|
|
uint32_t h0[5] = {0x67452301, 0xEFCDAB89, 0x98BADCFE, 0x10325476, 0xC3D2E1F0};
|
|
uint64_t total_len = in_len;
|
|
in[total_len++] = 0x80;
|
|
int padding_size = (64 - (total_len + 8) % 64) % 64;
|
|
while (padding_size--) in[total_len++] = 0;
|
|
for (uint64_t i = 0; i < total_len; i += 4) {
|
|
uint32_t& w = *(uint32_t*)(in + i);
|
|
w = be32toh(w);
|
|
}
|
|
*(uint32_t*)(in + total_len) = (uint32_t)(in_len >> 29);
|
|
*(uint32_t*)(in + total_len + 4) = (uint32_t)(in_len << 3);
|
|
for (uint8_t* in_end = in + total_len + 8; in < in_end; in += 64) {
|
|
uint32_t* w = (uint32_t*)in;
|
|
uint32_t h[5];
|
|
memcpy(h, h0, sizeof(h));
|
|
for (uint32_t i = 0, j = 0; i < 80; i++, j += 4) {
|
|
uint32_t &a = h[j % 5], &b = h[(j + 1) % 5], &c = h[(j + 2) % 5], &d = h[(j + 3) % 5], &e = h[(j + 4) % 5];
|
|
if (i >= 16) w[i & 15] = rol(w[(i + 13) & 15] ^ w[(i + 8) & 15] ^ w[(i + 2) & 15] ^ w[i & 15], 1);
|
|
if (i < 40) {
|
|
if (i < 20)
|
|
e += ((b & (c ^ d)) ^ d) + 0x5A827999;
|
|
else
|
|
e += (b ^ c ^ d) + 0x6ED9EBA1;
|
|
}
|
|
else {
|
|
if (i < 60)
|
|
e += (((b | c) & d) | (b & c)) + 0x8F1BBCDC;
|
|
else
|
|
e += (b ^ c ^ d) + 0xCA62C1D6;
|
|
}
|
|
e += w[i & 15] + rol(a, 5);
|
|
b = rol(b, 30);
|
|
}
|
|
for (int i = 0; i < 5; i++) h0[i] += h[i];
|
|
}
|
|
const char* base64tb = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
|
|
uint32_t triples[7] = {h0[0] >> 8,
|
|
(h0[0] << 16) | (h0[1] >> 16),
|
|
(h0[1] << 8) | (h0[2] >> 24),
|
|
h0[2],
|
|
h0[3] >> 8,
|
|
(h0[3] << 16) | (h0[4] >> 16),
|
|
h0[4] << 8};
|
|
for (uint32_t i = 0; i < 7; i++) {
|
|
out[i * 4] = base64tb[(triples[i] >> 18) & 63];
|
|
out[i * 4 + 1] = base64tb[(triples[i] >> 12) & 63];
|
|
out[i * 4 + 2] = base64tb[(triples[i] >> 6) & 63];
|
|
out[i * 4 + 3] = base64tb[triples[i] & 63];
|
|
}
|
|
out[27] = '=';
|
|
return 28;
|
|
}
|
|
|
|
uint32_t handleHttpRequest(EventHandler* handler, Connection& conn, const char* data, uint32_t size) {
|
|
const char* data_end = data + size;
|
|
const int ValueBufSize = 128;
|
|
char request_uri[1024] = {0};
|
|
char host[ValueBufSize] = {0};
|
|
char origin[ValueBufSize] = {0};
|
|
char wskey[ValueBufSize] = {0};
|
|
char wsprotocol[ValueBufSize] = {0};
|
|
char wsextensions[ValueBufSize] = {0};
|
|
bool upgrade_checked = false, connection_checked = false, wsversion_checked = false;
|
|
while (true) {
|
|
const char* ln = (char*)memchr(data, '\n', data_end - data);
|
|
if (!ln) return size;
|
|
if (*--ln != '\r') break;
|
|
if (request_uri[0] == 0) { // first line
|
|
if (memcmp(data, "GET ", 4)) break;
|
|
data += 4;
|
|
while (*data == ' ') data++;
|
|
const char* uri_end = (char*)memchr(data, ' ', ln - data);
|
|
uint32_t uri_len = uri_end - data;
|
|
if (!uri_end || uri_len >= sizeof(request_uri)) break;
|
|
memcpy(request_uri, data, uri_len);
|
|
request_uri[uri_len] = 0;
|
|
}
|
|
else {
|
|
const char* val_end = ln;
|
|
while (val_end[-1] == ' ') val_end--;
|
|
if (val_end == data) { // end of headers
|
|
if (!host[0] || !wskey[0] || !upgrade_checked || !connection_checked || !wsversion_checked) break;
|
|
char resp_wsprotocol[ValueBufSize] = {0};
|
|
char resp_wsextensions[ValueBufSize] = {0};
|
|
char resp[1024];
|
|
uint32_t resp_len = 0;
|
|
bool accept = handler->onWSConnect(
|
|
conn, request_uri, host, origin[0] ? origin : nullptr, wsprotocol[0] ? wsprotocol : nullptr,
|
|
wsextensions[0] ? wsextensions : nullptr, resp_wsprotocol, ValueBufSize, resp_wsextensions, ValueBufSize);
|
|
if (accept) {
|
|
conn.open = true;
|
|
memcpy(wskey + 24, "258EAFA5-E914-47DA-95CA-C5AB0DC85B11", 36);
|
|
char accept_str[32];
|
|
accept_str[sha1base64((uint8_t*)wskey, 24 + 36, accept_str)] = 0;
|
|
resp_len = sprintf(resp,
|
|
"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: "
|
|
"Upgrade\r\nSec-WebSocket-Accept: %s\r\n",
|
|
accept_str);
|
|
}
|
|
else {
|
|
resp_len = sprintf(resp, "HTTP/1.1 403 Forbidden\r\nSec-WebSocket-Version: 13\r\n");
|
|
}
|
|
if (resp_wsprotocol[0])
|
|
resp_len += sprintf(resp + resp_len, "Sec-WebSocket-Protocol: %s\r\n", resp_wsprotocol);
|
|
if (resp_wsextensions[0])
|
|
resp_len += sprintf(resp + resp_len, "Sec-WebSocket-Extensions: %s\r\n", resp_wsextensions);
|
|
resp_len += sprintf(resp + resp_len, "\r\n");
|
|
conn.conn.write((uint8_t*)resp, resp_len);
|
|
return data_end - ln - 2;
|
|
}
|
|
const char* colon = (char*)memchr(data, ':', ln - data);
|
|
if (!colon) break;
|
|
const char* val = colon + 1;
|
|
while (*val == ' ') val++;
|
|
uint32_t key_len = colon - data;
|
|
uint32_t val_len = val_end - val;
|
|
if (val_len < ValueBufSize) {
|
|
if (key_len == 4 && !memcmp(data, "Host", 4)) {
|
|
memcpy(host, val, val_len);
|
|
host[val_len] = 0;
|
|
}
|
|
else if (key_len == 6 && !memcmp(data, "Origin", 6)) {
|
|
memcpy(origin, val, val_len);
|
|
origin[val_len] = 0;
|
|
}
|
|
else if (key_len == 7 && !memcmp(data, "Upgrade", 7)) {
|
|
if (memcmp(val, "websocket", 9)) break;
|
|
upgrade_checked = true;
|
|
}
|
|
else if (key_len == 10 && !memcmp(data, "Connection", 10)) {
|
|
if (!memcmp(val, "Upgrade", 7)) connection_checked = true;
|
|
}
|
|
else if (key_len == 17 && !memcmp(data, "Sec-WebSocket-Key", 17)) {
|
|
if (val_len != 24) break;
|
|
memcpy(wskey, val, val_len);
|
|
}
|
|
else if (key_len == 21 && !memcmp(data, "Sec-WebSocket-Version", 21)) {
|
|
if (val_len != 2 || memcmp(val, "13", 2)) break;
|
|
wsversion_checked = true;
|
|
}
|
|
else if (key_len == 22 && !memcmp(data, "Sec-WebSocket-Protocol", 22)) {
|
|
memcpy(wsprotocol, val, val_len);
|
|
wsprotocol[val_len] = 0;
|
|
}
|
|
else if (key_len == 24 && !memcmp(data, "Sec-WebSocket-Extensions", 24)) {
|
|
memcpy(wsextensions, val, val_len);
|
|
wsextensions[val_len] = 0;
|
|
}
|
|
}
|
|
}
|
|
data = ln + 2; // skip \r\n
|
|
}
|
|
const char* resp400 = "HTTP/1.1 400 Bad Request\r\nSec-WebSocket-Version: 13\r\n\r\n";
|
|
conn.conn.write((uint8_t*)resp400, strlen(resp400));
|
|
conn.conn.close("bad request");
|
|
return size;
|
|
}
|
|
|
|
private:
|
|
uint64_t newconn_timeout_;
|
|
uint64_t openconn_timeout_;
|
|
TcpServer server_;
|
|
|
|
uint32_t conns_cnt_ = 0;
|
|
Connection* conns_[MaxConns];
|
|
Connection conns_data_[MaxConns];
|
|
};
|
|
|
|
} // namespace websocket
|