power_play/src/sock_win32.c
2025-02-23 03:23:17 -06:00

452 lines
13 KiB
C

#include "sock.h"
#include "sys.h"
#include "log.h"
#include "arena.h"
#include "scratch.h"
#include "string.h"
#include "log.h"
#include "gstat.h"
#define WIN32_LEAN_AND_MEAN
#define UNICODE
#include <Windows.h>
#include <WinSock2.h>
#include <WS2tcpip.h>
#pragma comment(lib, "ws2_32.lib")
//#define MAX_IP_STR_LEN 46
#define MAX_POLL_FDS 64
/* A winsock address together with the bookkeeping the socket calls in this
 * file need alongside it. */
struct win32_address {
/* Byte length of the address; passed to bind/sendto and filled in by recvfrom/getsockname */
i32 size;
/* Winsock address family (AF_INET or AF_INET6 in this file) */
i32 family;
/* Overlapping views of the same address storage */
union {
struct sockaddr_storage sas;
struct sockaddr sa;
struct sockaddr_in sin;
struct sockaddr_in6 sin6;
};
};
/* Backing struct for the opaque `struct sock` handle handed out by sock_alloc */
struct win32_sock {
/* Underlying winsock socket handle */
SOCKET sock;
/* Next entry in the free list; set by win32_sock_release */
struct win32_sock *next_free;
};
/* ========================== *
* Global state
* ========================== */
/* Module-wide state for the win32 socket backend */
GLOBAL struct {
/* Filled in by WSAStartup in sock_startup */
WSADATA wsa_data;
/* Arena that win32_sock structs are pushed onto */
struct arena win32_socks_arena;
/* Held while pushing onto the arena or touching the free list below */
struct sys_mutex win32_socks_mutex;
/* Head of the list of released win32_sock structs available for reuse */
struct win32_sock *first_free_win32_sock;
} G = ZI, DEBUG_ALIAS(G, G_sock_win32);
/* ========================== *
* Startup
* ========================== */
struct sock_startup_receipt sock_startup(void)
{
/* Startup winsock */
WSAStartup(MAKEWORD(2, 2), &G.wsa_data);
G.win32_socks_arena = arena_alloc(GIGABYTE(64));
G.win32_socks_mutex = sys_mutex_alloc();
return (struct sock_startup_receipt) { 0 };
}
/* ========================== *
* Address
* ========================== */
/* Resolves a host / port pair (either may be NULL) into a sock_address using
 * getaddrinfo with datagram + passive hints.
 *
 * The first IPv4 result is used; IPv6 results are currently skipped. If
 * resolution fails or yields no IPv4 result, the returned address is zeroed
 * (`valid` is false). The port and ip are stored in network byte order, exactly
 * as getaddrinfo produced them. */
INTERNAL struct sock_address sock_address_from_ip_port_cstr(char *ip_cstr, char *port_cstr)
{
    struct sock_address res = ZI;

    struct addrinfo hints = ZI;
    hints.ai_family = AF_UNSPEC;
    hints.ai_socktype = SOCK_DGRAM;
    hints.ai_flags = AI_PASSIVE;

    /* Keep the list head around: freeaddrinfo must be handed the pointer that
     * getaddrinfo returned, not wherever iteration happened to stop. */
    struct addrinfo *ai_list = NULL;
    i32 status = getaddrinfo(ip_cstr, port_cstr, &hints, &ai_list);
    if (status == 0) {
        for (struct addrinfo *ai = ai_list; ai; ai = ai->ai_next) {
            if (ai->ai_family == AF_INET) {
                struct sockaddr_in *sockaddr = (struct sockaddr_in *)ai->ai_addr;
                res.valid = true;
                res.family = SOCK_ADDRESS_FAMILY_IPV4;
                res.portnb = sockaddr->sin_port;
                CT_ASSERT(sizeof(sockaddr->sin_addr) == 4);
                MEMCPY(res.ipnb, (void *)&sockaddr->sin_addr, 4);
                break;
            } else if (ai->ai_family == AF_INET6) {
                /* TODO: Enable ipv6 */
#if 0
                struct sockaddr_in6 *sockaddr = (struct sockaddr_in6 *)ai->ai_addr;
                res.valid = true;
                res.family = SOCK_ADDRESS_FAMILY_IPV6;
                res.portnb = sockaddr->sin6_port;
                CT_ASSERT(sizeof(sockaddr->sin6_addr) == 16);
                MEMCPY(res.ipnb, (void *)&sockaddr->sin6_addr, 16);
                break;
#endif
            }
        }
        freeaddrinfo(ai_list);
    }
    return res;
}
/* Parses a textual address into a sock_address.
 *
 * Three input shapes are recognized:
 *   - "[host]:port"  when the string starts with '[' and has more than one colon
 *   - "host:port"    when the string contains exactly one colon
 *   - "host"         otherwise (the whole string is treated as the host)
 *
 * Input beyond the internal buffer capacity is ignored. The split host / port
 * are then resolved by sock_address_from_ip_port_cstr; an empty host or port is
 * passed along as NULL. */
struct sock_address sock_address_from_string(struct string str)
{
    u8 host_chars[1024];
    u8 port_chars[ARRAY_COUNT(host_chars)];
    u64 host_len = 0;
    u64 port_len = 0;

    /* The number of colons decides which input shape we are looking at */
    u64 num_colons = 0;
    for (u64 pos = 0; pos < str.len; ++pos) {
        if (str.text[pos] == ':') {
            ++num_colons;
        }
    }

    /* Leave room for a null terminator in both buffers */
    u64 scan_len = min_u64(min_u64(str.len, ARRAY_COUNT(host_chars) - 1), ARRAY_COUNT(port_chars) - 1);

    if (num_colons > 1 && str.text[0] == '[') {
        /* "[host]:port" - skip the opening bracket, host runs until ']' */
        b32 in_host = true;
        for (u64 pos = 1; pos < scan_len; ++pos) {
            u8 ch = str.text[pos];
            if (!in_host) {
                /* Everything after ']' except colons belongs to the port */
                if (ch != ':') {
                    port_chars[port_len] = ch;
                    ++port_len;
                }
            } else if (ch == ']') {
                in_host = false;
            } else {
                host_chars[host_len] = ch;
                ++host_len;
            }
        }
    } else if (num_colons == 1) {
        /* "host:port" - the single colon separates the two */
        b32 in_host = true;
        for (u64 pos = 0; pos < scan_len; ++pos) {
            u8 ch = str.text[pos];
            if (!in_host) {
                port_chars[port_len] = ch;
                ++port_len;
            } else if (ch == ':') {
                in_host = false;
            } else {
                host_chars[host_len] = ch;
                ++host_len;
            }
        }
    } else {
        /* "host" - no port present */
        host_len = min_u64(str.len, ARRAY_COUNT(host_chars) - 1);
        MEMCPY(host_chars, str.text, host_len);
    }

    /* Empty pieces are handed to the resolver as NULL */
    char *ip_cstr = NULL;
    if (host_len > 0) {
        host_chars[host_len] = 0;
        ip_cstr = (char *)host_chars;
    }
    char *port_cstr = NULL;
    if (port_len > 0) {
        port_chars[port_len] = 0;
        port_cstr = (char *)port_chars;
    }

    return sock_address_from_ip_port_cstr(ip_cstr, port_cstr);
}
/* Builds a sock_address for `port` with no host specified.
 *
 * The port is rendered as a decimal string and resolved through
 * sock_address_from_ip_port_cstr with a NULL host. A port of 0 produces no
 * digits, so the port string is passed as NULL as well. */
struct sock_address sock_address_from_port(u16 port)
{
    u8 digits[128];

    /* Write digits right-to-left, ending at a null terminator in the last slot */
    u64 cursor = ARRAY_COUNT(digits) - 1;
    digits[cursor] = 0;
    for (u16 remaining = port; remaining > 0 && cursor > 0; remaining /= 10) {
        --cursor;
        digits[cursor] = '0' + (u8)(remaining % 10);
    }

    /* No digits written means there is no port string to pass */
    char *port_cstr = NULL;
    if (cursor < ARRAY_COUNT(digits) - 1) {
        port_cstr = (char *)&digits[cursor];
    }

    return sock_address_from_ip_port_cstr(NULL, port_cstr);
}
/* Formats `address` as "a.b.c.d:port", allocating the string from `arena`.
 *
 * IPv6 addresses are not supported yet and produce an empty string. Any other
 * family is formatted as IPv4. */
struct string sock_string_from_address(struct arena *arena, struct sock_address address)
{
    struct string res = ZI;
    if (address.family == SOCK_ADDRESS_FAMILY_IPV6) {
        /* TODO */
    } else {
        /* `ipnb` holds the address as individual bytes in network order, so
         * the octets are already in printing order. Byte-swapping a single
         * byte with ntohs (and truncating back to u8) would zero it on
         * little-endian hosts. */
        u8 ip[4];
        for (u32 i = 0; i < 4; ++i) {
            ip[i] = address.ipnb[i];
        }
        /* The port is a real 16-bit value in network order and does need swapping */
        u16 port = ntohs(address.portnb);
        res = string_format(arena, LIT("%F.%F.%F.%F:%F"), FMT_UINT(ip[0]), FMT_UINT(ip[1]), FMT_UINT(ip[2]), FMT_UINT(ip[3]), FMT_UINT(port));
    }
    return res;
}
INTERNAL struct win32_address win32_address_from_sock_address(struct sock_address addr)
{
struct win32_address res = ZI;
if (addr.family == SOCK_ADDRESS_FAMILY_IPV4) {
res.family = AF_INET;
res.size = sizeof(struct sockaddr_in);
res.sin.sin_port = addr.portnb;
res.sin.sin_family = res.family;
MEMCPY(&res.sin.sin_addr, addr.ipnb, 4);
} else {
res.family = AF_INET6;
res.sin6.sin6_port = addr.portnb;
res.sin6.sin6_family = res.family;
res.size = sizeof(struct sockaddr_in6);
MEMCPY(&res.sin6.sin6_addr.s6_addr, addr.ipnb, 16);
}
return res;
}
/* If the supplied address has the wildcard ip (all zero bytes: INADDR_ANY for
 * IPv4, the unspecified address for IPv6), returns a copy with the ip replaced
 * by the loopback address (127.0.0.1 or ::1). Any other address, or an address
 * of an unrecognized family, is returned unchanged. */
INTERNAL struct win32_address win32_address_convert_any_to_localhost(struct win32_address addr)
{
    if (addr.family == AF_INET) {
        u8 *bytes = (u8 *)&addr.sin.sin_addr;
        b32 is_any = true;
        for (u64 i = 0; i < 4; ++i) {
            if (bytes[i] != 0) {
                is_any = false;
                break;
            }
        }
        if (is_any) {
            /* 127.0.0.1 in network byte order */
            bytes[0] = 127;
            bytes[3] = 1;
        }
    } else if (addr.family == AF_INET6) {
        /* Must look at the IPv6 view of the union: sin6_addr lives at a
         * different offset than sin_addr and is 16 bytes long. */
        u8 *bytes = (u8 *)&addr.sin6.sin6_addr;
        b32 is_any = true;
        for (u64 i = 0; i < 16; ++i) {
            if (bytes[i] != 0) {
                is_any = false;
                break;
            }
        }
        if (is_any) {
            /* ::1 */
            bytes[15] = 1;
        }
    }
    return addr;
}
INTERNAL struct sock_address sock_address_from_win32_address(struct win32_address ws_addr)
{
struct sock_address res = ZI;
if (ws_addr.family == AF_INET) {
res.family = SOCK_ADDRESS_FAMILY_IPV4;
res.portnb = ws_addr.sin.sin_port;
MEMCPY(res.ipnb, &ws_addr.sin.sin_addr, 4);
res.valid = true;
} else if (ws_addr.family == AF_INET6) {
res.family = SOCK_ADDRESS_FAMILY_IPV6;
res.portnb = ws_addr.sin6.sin6_port;
MEMCPY(res.ipnb, &ws_addr.sin6.sin6_addr.s6_addr, 16);
res.valid = true;
}
return res;
}
/* ========================== *
* Sock
* ========================== */
/* Returns a zeroed win32_sock, reusing an entry from the free list when one is
 * available and pushing a new one onto the arena otherwise. The free list and
 * arena are accessed under win32_socks_mutex. */
INTERNAL struct win32_sock *win32_sock_alloc(void)
{
    struct sys_lock lock = sys_mutex_lock_e(&G.win32_socks_mutex);
    struct win32_sock *slot = G.first_free_win32_sock;
    if (slot) {
        /* Pop the head of the free list */
        G.first_free_win32_sock = slot->next_free;
    } else {
        slot = arena_push(&G.win32_socks_arena, struct win32_sock);
    }
    sys_mutex_unlock(&lock);

    /* Recycled entries still carry stale contents; clear outside the lock */
    MEMZERO_STRUCT(slot);
    return slot;
}
INTERNAL void win32_sock_release(struct win32_sock *ws)
{
struct sys_lock lock = sys_mutex_lock_e(&G.win32_socks_mutex);
ws->next_free = G.first_free_win32_sock;
G.first_free_win32_sock = ws;
sys_mutex_unlock(&lock);
}
/* Creates a UDP socket, applies the requested send / receive buffer sizes and
 * binds it to `listen_port`.
 *
 * The returned handle is a win32_sock in disguise and must be given back with
 * sock_release. Return values of the winsock calls are not checked here. */
struct sock *sock_alloc(u16 listen_port, u64 sndbuf_size, u64 rcvbuf_size)
{
    struct win32_sock *ws = win32_sock_alloc();

    /* Address to bind to, derived from the port alone */
    struct win32_address local = win32_address_from_sock_address(sock_address_from_port(listen_port));

    ws->sock = socket(local.family, SOCK_DGRAM, IPPROTO_UDP);

    /* setsockopt expects the buffer sizes as 32-bit ints */
    i32 send_bytes = sndbuf_size;
    i32 recv_bytes = rcvbuf_size;
    setsockopt(ws->sock, SOL_SOCKET, SO_SNDBUF, (char *)&send_bytes, sizeof(send_bytes));
    setsockopt(ws->sock, SOL_SOCKET, SO_RCVBUF, (char *)&recv_bytes, sizeof(recv_bytes));

    bind(ws->sock, &local.sa, local.size);

    return (struct sock *)ws;
}
/* Closes the underlying winsock socket and returns the handle's storage to the
 * pool. `sock` must not be used afterwards. */
void sock_release(struct sock *sock)
{
    struct win32_sock *handle = (struct win32_sock *)sock;

    closesocket(handle->sock);
    win32_sock_release(handle);
}
/* Send an empty dummy packet to wake anyone blocking on read (dumb hack since
* winsock doesn't have eventfd).
*
* TODO: Use WSAEvent and WSAWaitForMultipleEvents instead */
void sock_wake(struct sock *sock)
{
struct win32_sock *ws = (struct win32_sock *)sock;
/* Get bound address as localhost so we can write to it (if bound to INADDR_ANY) */
struct win32_address bind_address = ZI;
{
i32 len = sizeof(bind_address.sas);
getsockname(ws->sock, &bind_address.sa, &len);
bind_address.family = bind_address.sin.sin_family;
bind_address.size = len;
bind_address = win32_address_convert_any_to_localhost(bind_address);
}
/* Have sock send an empty dummy packet to itself to signal read available */
sendto(ws->sock, "", 0, 0, &bind_address.sa, bind_address.size);
}
/* ========================== *
* Read
* ========================== */
/* Blocks until one of `socks` has data available to read or `timeout`
 * (in seconds) elapses. Pass F32_INFINITY to wait without a timeout.
 *
 * Returns the first socket (in array order) that is readable, or NULL if the
 * wait timed out or failed. At most MAX_POLL_FDS sockets are polled; any
 * entries beyond that are ignored. */
struct sock *sock_wait_for_available_read(struct sock_array socks, f32 timeout)
{
    struct sock *res = NULL;
    WSAPOLLFD fds[MAX_POLL_FDS] = ZI;

    /* The poll set is a fixed-size stack array; never index past it */
    u64 count = socks.count;
    if (count > MAX_POLL_FDS) {
        count = MAX_POLL_FDS;
    }

    for (u64 i = 0; i < count; ++i) {
        struct win32_sock *ws = (struct win32_sock *)socks.socks[i];
        fds[i].fd = ws->sock;
        fds[i].events = POLLRDNORM;
    }

    /* WSAPoll takes milliseconds; a negative value means wait indefinitely */
    i32 timeout_ms;
    if (timeout == F32_INFINITY) {
        timeout_ms = -1;
    } else {
        timeout_ms = (i32)(timeout * 1000);
    }

    /* Only inspect revents when WSAPoll reports at least one ready socket
     * (0 = timed out, SOCKET_ERROR = failure). */
    i32 num_ready = WSAPoll(fds, (ULONG)count, timeout_ms);
    if (num_ready > 0) {
        for (u64 i = 0; i < count; ++i) {
            if (fds[i].revents & POLLRDNORM) {
                res = socks.socks[i];
                break;
            }
        }
    }
    return res;
}
struct sock_read_result sock_read(struct sock *sock, struct string read_buff)
{
struct win32_sock *ws = (struct win32_sock *)sock;
struct sock_read_result res = ZI;
struct win32_address ws_addr = ZI;
ws_addr.size = sizeof(ws_addr.sas);
i32 size = recvfrom(ws->sock, (char *)read_buff.text, read_buff.len, 0, &ws_addr.sa, &ws_addr.size);
ws_addr.family = ws_addr.sin.sin_family;
res.address = sock_address_from_win32_address(ws_addr);
if (size >= 0) {
gstat_add(GSTAT_SOCK_BYTES_RECEIVED, size);
res.data.text = read_buff.text;
res.data.len = size;
res.valid = true;
} else {
#if RTC
i32 err = WSAGetLastError();
if (err != WSAEWOULDBLOCK && err != WSAETIMEDOUT && err != WSAECONNRESET) {
ASSERT(false);
}
#endif
}
return res;
}
/* ========================== *
* Write
* ========================== */
void sock_write(struct sock *sock, struct sock_address address, struct string data)
{
struct win32_sock *ws = (struct win32_sock *)sock;
struct win32_address ws_addr = win32_address_from_sock_address(address);
i32 size = sendto(ws->sock, (char *)data.text, data.len, 0, &ws_addr.sa, ws_addr.size);
if (size > 0) {
gstat_add(GSTAT_SOCK_BYTES_SENT, size);
}
#if RTC
if (size != (i32)data.len) {
i32 err = WSAGetLastError();
(UNUSED)err;
ASSERT(false);
}
#endif
}