only sync lanes once during wave broadcast

This commit is contained in:
jacob 2025-12-06 10:25:31 -06:00
parent 97237b5ed9
commit fd1ada7fe5
4 changed files with 82 additions and 35 deletions

View File

@ -7,18 +7,18 @@ void WaveSyncEx(WaveLaneCtx *lane, u64 spin_count)
i32 lanes_count = wave->lanes_count; i32 lanes_count = wave->lanes_count;
if (lanes_count > 0) if (lanes_count > 0)
{ {
i64 barrier_gen = Atomic64Fetch(&wave->barrier_gen.v); i64 sync_gen = Atomic64Fetch(&wave->sync_gen.v);
i32 blocked_count = Atomic32FetchAdd(&wave->barrier_blocked_count, 1) + 1; i32 blocked_count = Atomic32FetchAdd(&wave->sync_count.v, 1) + 1;
if (blocked_count == lanes_count) if (blocked_count == lanes_count)
{ {
Atomic32Set(&wave->barrier_blocked_count, 0); Atomic32Set(&wave->sync_count.v, 0);
Atomic64FetchAdd(&wave->barrier_gen.v, barrier_gen + 1); Atomic64FetchAdd(&wave->sync_gen.v, sync_gen + 1);
FutexWakeNeq(&wave->barrier_gen.v); FutexWakeNeq(&wave->sync_gen.v);
} }
else else
{ {
u64 remaining_spins = spin_count; u64 remaining_spins = spin_count;
while (Atomic64Fetch(&wave->barrier_gen.v) == barrier_gen) while (Atomic64Fetch(&wave->sync_gen.v) == sync_gen)
{ {
if (remaining_spins > 0) if (remaining_spins > 0)
{ {
@ -27,7 +27,7 @@ void WaveSyncEx(WaveLaneCtx *lane, u64 spin_count)
} }
else else
{ {
FutexYieldNeq(&wave->barrier_gen.v, &barrier_gen, sizeof(barrier_gen)); FutexYieldNeq(&wave->sync_gen.v, &sync_gen, sizeof(sync_gen));
} }
} }
} }
@ -37,17 +37,69 @@ void WaveSyncEx(WaveLaneCtx *lane, u64 spin_count)
void WaveSyncBroadcastEx_(WaveLaneCtx *lane, u32 broadcast_lane_idx, void *broadcast_ptr, u64 broadcast_size, u64 spin_count) void WaveSyncBroadcastEx_(WaveLaneCtx *lane, u32 broadcast_lane_idx, void *broadcast_ptr, u64 broadcast_size, u64 spin_count)
{ {
WaveCtx *wave = lane->wave; WaveCtx *wave = lane->wave;
u32 lane_idx = lane->idx; i32 lanes_count = wave->lanes_count;
if (lane_idx == broadcast_lane_idx) if (lanes_count > 1)
{ {
wave->barrier_broadcast_data = broadcast_ptr; u32 lane_idx = lane->idx;
if (lane_idx == broadcast_lane_idx)
{
/* Broadcast */
wave->broadcast_data = broadcast_ptr;
i64 ack_gen = Atomic64Fetch(&wave->ack_gen.v);
lane->seen_broadcast_gen = Atomic64FetchAdd(&wave->broadcast_gen.v, 1) + 1;
FutexWakeNeq(&wave->broadcast_gen.v);
/* Wait for ack */
{
u64 remaining_spins = spin_count;
while (Atomic64Fetch(&wave->ack_gen.v) == ack_gen)
{
if (remaining_spins > 0)
{
--remaining_spins;
_mm_pause();
}
else
{
FutexYieldNeq(&wave->ack_gen.v, &ack_gen, sizeof(ack_gen));
}
}
}
}
else
{
/* Wait for broadcast */
i64 seen_broadcast_gen = lane->seen_broadcast_gen++;
{
u64 remaining_spins = spin_count;
while (Atomic64Fetch(&wave->broadcast_gen.v) == seen_broadcast_gen)
{
if (remaining_spins > 0)
{
--remaining_spins;
_mm_pause();
}
else
{
FutexYieldNeq(&wave->broadcast_gen.v, &seen_broadcast_gen, sizeof(seen_broadcast_gen));
}
}
}
/* Copy broadcast data */
CopyBytes(broadcast_ptr, wave->broadcast_data, broadcast_size);
/* Ack */
i32 ack_count = Atomic32FetchAdd(&wave->ack_count.v, 1) + 1;
if (ack_count == lanes_count - 1)
{
Atomic32Set(&wave->ack_count.v, 0);
Atomic64FetchAdd(&wave->ack_gen.v, 1);
FutexWakeNeq(&wave->ack_gen.v);
}
}
} }
WaveSyncEx(lane, spin_count);
if (lane_idx != broadcast_lane_idx)
{
CopyBytes(broadcast_ptr, wave->barrier_broadcast_data, broadcast_size);
}
WaveSyncEx(lane, spin_count);
} }
void SetWaveLaneDefaultSpin(WaveLaneCtx *lane, u64 n) void SetWaveLaneDefaultSpin(WaveLaneCtx *lane, u64 n)

View File

@ -3,21 +3,27 @@
#define DefaultWaveLaneSpinCount 500 #define DefaultWaveLaneSpinCount 500
Struct(WaveCtx) AlignedStruct(WaveCtx, CachelineSize)
{ {
i32 lanes_count; i32 lanes_count;
/* Barrier */ /* Sync barrier */
void *barrier_broadcast_data; Atomic32Padded sync_count;
Atomic32 barrier_blocked_count; Atomic64Padded sync_gen;
Atomic64Padded barrier_gen;
/* Broadcast barrier */
void *broadcast_data;
Atomic64Padded broadcast_gen;
Atomic32Padded ack_count;
Atomic64Padded ack_gen;
}; };
Struct(WaveLaneCtx) AlignedStruct(WaveLaneCtx, CachelineSize)
{ {
i32 idx; i32 idx;
WaveCtx *wave; WaveCtx *wave;
u64 default_spin_count; u64 default_spin_count;
i64 seen_broadcast_gen;
}; };
typedef void WaveLaneEntryFunc(WaveLaneCtx *lane, void *udata); typedef void WaveLaneEntryFunc(WaveLaneCtx *lane, void *udata);

View File

@ -15,7 +15,6 @@ void W32_InitCurrentThread(String name)
Arena *perm = PermArena(); Arena *perm = PermArena();
/* Fixme: Set thread name */
wchar_t *thread_name_wstr = WstrFromString(perm, name); wchar_t *thread_name_wstr = WstrFromString(perm, name);
SetThreadDescription(GetCurrentThread(), thread_name_wstr); SetThreadDescription(GetCurrentThread(), thread_name_wstr);
@ -37,24 +36,14 @@ DWORD WINAPI W32_ThreadProc(LPVOID thread_args_vp)
void DispatchWave(String name, u32 num_lanes, WaveLaneEntryFunc *entry, void *udata) void DispatchWave(String name, u32 num_lanes, WaveLaneEntryFunc *entry, void *udata)
{ {
/* FIXME: Impl */
Arena *perm = PermArena(); Arena *perm = PermArena();
WaveCtx *wave_ctx = 0; WaveCtx *wave_ctx = PushStruct(perm, WaveCtx);
{
PushAlign(perm, CachelineSize);
wave_ctx = PushStruct(perm, WaveCtx);
PushAlign(perm, CachelineSize);
}
wave_ctx->lanes_count = num_lanes; wave_ctx->lanes_count = num_lanes;
for (u32 lane_idx = 0; lane_idx < num_lanes; ++lane_idx) for (u32 lane_idx = 0; lane_idx < num_lanes; ++lane_idx)
{ {
PushAlign(perm, CachelineSize);
WaveLaneCtx *lane_ctx = PushStruct(perm, WaveLaneCtx); WaveLaneCtx *lane_ctx = PushStruct(perm, WaveLaneCtx);
PushAlign(perm, CachelineSize);
lane_ctx->idx = lane_idx; lane_ctx->idx = lane_idx;
lane_ctx->wave = wave_ctx; lane_ctx->wave = wave_ctx;
lane_ctx->default_spin_count = DefaultWaveLaneSpinCount; lane_ctx->default_spin_count = DefaultWaveLaneSpinCount;

View File

@ -1125,6 +1125,6 @@ void StartupLayers(void)
{ {
OS_Startup(); OS_Startup();
CpuTopologyInfo cpu_info = GetCpuTopologyInfo(); CpuTopologyInfo cpu_info = GetCpuTopologyInfo();
i32 meta_lanes_count = cpu_info.num_logical_cores; i32 meta_lanes_count = cpu_info.num_logical_cores - 1;
DispatchWave(Lit("Meta"), MaxI32(meta_lanes_count, 1), BuildEntryPoint, 0); DispatchWave(Lit("Meta"), MaxI32(meta_lanes_count, 1), BuildEntryPoint, 0);
} }