only sync lanes once during wave broadcast

This commit is contained in:
jacob 2025-12-06 10:25:31 -06:00
parent 97237b5ed9
commit fd1ada7fe5
4 changed files with 82 additions and 35 deletions

View File

@ -7,18 +7,18 @@ void WaveSyncEx(WaveLaneCtx *lane, u64 spin_count)
/*
 * Barrier over all lanes of the wave. Each arriving lane samples the generation and
 * bumps the shared arrival counter. The arrival that brings the counter up to
 * lanes_count resets the counter, changes the generation and calls FutexWakeNeq;
 * every earlier arrival loops until the generation differs from its sample.
 *
 * NOTE(review): the barrier_* and sync_* statements below come in adjacent pairs
 * (blocked_count is even declared twice in one scope). This looks like the before
 * and after lines of a rename listed together -- confirm which of each pair is live.
 */
i32 lanes_count = wave->lanes_count;
if (lanes_count > 0)
{
/* Sample the generation BEFORE registering arrival, so a release that happens right
 * after the increment is still seen as a change by the wait loop. */
i64 barrier_gen = Atomic64Fetch(&wave->barrier_gen.v);
i32 blocked_count = Atomic32FetchAdd(&wave->barrier_blocked_count, 1) + 1;
i64 sync_gen = Atomic64Fetch(&wave->sync_gen.v);
i32 blocked_count = Atomic32FetchAdd(&wave->sync_count.v, 1) + 1;
if (blocked_count == lanes_count)
{
/*
 * Last lane to arrive: reset the counter for the next round, then release the others.
 *
 * NOTE(review): likely bug on both FetchAdd lines in this branch. The addend is
 * (gen + 1), not 1, so (gen + 1) doubles on every completed sync. Assuming
 * Atomic64FetchAdd adds its second argument with 64-bit wraparound (TODO confirm),
 * gen + 1 becomes 0 after at most 64 syncs; from then on the addend is 0, the
 * generation never changes again, and waiting lanes never leave the loop below.
 * Suggested fix: Atomic64FetchAdd(&wave->sync_gen.v, 1).
 */
Atomic32Set(&wave->barrier_blocked_count, 0);
Atomic64FetchAdd(&wave->barrier_gen.v, barrier_gen + 1);
FutexWakeNeq(&wave->barrier_gen.v);
Atomic32Set(&wave->sync_count.v, 0);
Atomic64FetchAdd(&wave->sync_gen.v, sync_gen + 1);
FutexWakeNeq(&wave->sync_gen.v);
}
else
{
/* Not the last lane: wait until the generation moves away from the sampled value. */
u64 remaining_spins = spin_count;
while (Atomic64Fetch(&wave->barrier_gen.v) == barrier_gen)
while (Atomic64Fetch(&wave->sync_gen.v) == sync_gen)
{
if (remaining_spins > 0)
{
@ -27,7 +27,7 @@ void WaveSyncEx(WaveLaneCtx *lane, u64 spin_count)
}
else
{
FutexYieldNeq(&wave->barrier_gen.v, &barrier_gen, sizeof(barrier_gen));
FutexYieldNeq(&wave->sync_gen.v, &sync_gen, sizeof(sync_gen));
}
}
}
@ -37,17 +37,69 @@ void WaveSyncEx(WaveLaneCtx *lane, u64 spin_count)
/*
 * Shares `broadcast_size` bytes owned by lane `broadcast_lane_idx` with every other
 * lane of the wave. The broadcasting lane publishes `broadcast_ptr` through the shared
 * wave context; every other lane copies from the published pointer into its own
 * `broadcast_ptr`. `spin_count` is the number of busy-wait iterations a waiter performs
 * before it switches to FutexYieldNeq.
 *
 * NOTE(review): this listing appears to interleave pre-change and post-change lines
 * (two different outer `if` conditions, both barrier_broadcast_data and broadcast_data).
 * Confirm against the real file which lines are live.
 */
void WaveSyncBroadcastEx_(WaveLaneCtx *lane, u32 broadcast_lane_idx, void *broadcast_ptr, u64 broadcast_size, u64 spin_count)
{
WaveCtx *wave = lane->wave;
u32 lane_idx = lane->idx;
if (lane_idx == broadcast_lane_idx)
i32 lanes_count = wave->lanes_count;
/* The broadcast/ack handshake only runs when there is more than one lane. */
if (lanes_count > 1)
{
wave->barrier_broadcast_data = broadcast_ptr;
u32 lane_idx = lane->idx;
if (lane_idx == broadcast_lane_idx)
{
/* Broadcast */
wave->broadcast_data = broadcast_ptr;
/* Sample the ack generation before announcing the broadcast, so an ack that
 * completes immediately afterwards is still observed as a change below. */
i64 ack_gen = Atomic64Fetch(&wave->ack_gen.v);
/* Bump the shared broadcast generation and store the result in this lane's
 * seen_broadcast_gen (the `+ 1` implies FetchAdd yields the pre-increment value). */
lane->seen_broadcast_gen = Atomic64FetchAdd(&wave->broadcast_gen.v, 1) + 1;
FutexWakeNeq(&wave->broadcast_gen.v);
/* Wait for ack */
{
u64 remaining_spins = spin_count;
while (Atomic64Fetch(&wave->ack_gen.v) == ack_gen)
{
if (remaining_spins > 0)
{
--remaining_spins;
_mm_pause();
}
else
{
FutexYieldNeq(&wave->ack_gen.v, &ack_gen, sizeof(ack_gen));
}
}
}
}
else
{
/* Wait for broadcast */
/* Take the generation this lane last saw and advance the lane-local copy by one. */
i64 seen_broadcast_gen = lane->seen_broadcast_gen++;
{
u64 remaining_spins = spin_count;
while (Atomic64Fetch(&wave->broadcast_gen.v) == seen_broadcast_gen)
{
if (remaining_spins > 0)
{
--remaining_spins;
_mm_pause();
}
else
{
FutexYieldNeq(&wave->broadcast_gen.v, &seen_broadcast_gen, sizeof(seen_broadcast_gen));
}
}
}
/* Copy broadcast data */
CopyBytes(broadcast_ptr, wave->broadcast_data, broadcast_size);
/* Ack */
/* The last of the (lanes_count - 1) receiving lanes resets the counter and bumps
 * ack_gen, which is the change the broadcasting lane's wait loop is watching for. */
i32 ack_count = Atomic32FetchAdd(&wave->ack_count.v, 1) + 1;
if (ack_count == lanes_count - 1)
{
Atomic32Set(&wave->ack_count.v, 0);
Atomic64FetchAdd(&wave->ack_gen.v, 1);
FutexWakeNeq(&wave->ack_gen.v);
}
}
}
/* NOTE(review): the remaining lines hold two WaveSyncEx calls around a copy from
 * barrier_broadcast_data. Given the commit title ("only sync lanes once"), some of
 * them are presumably removed lines -- confirm which call survives. */
WaveSyncEx(lane, spin_count);
if (lane_idx != broadcast_lane_idx)
{
CopyBytes(broadcast_ptr, wave->barrier_broadcast_data, broadcast_size);
}
WaveSyncEx(lane, spin_count);
}
void SetWaveLaneDefaultSpin(WaveLaneCtx *lane, u64 n)

View File

@ -3,21 +3,27 @@
#define DefaultWaveLaneSpinCount 500
/*
 * State shared by every lane of one wave; each WaveLaneCtx points at its wave's WaveCtx.
 *
 * NOTE(review): two struct headers and both barrier_* and sync_*/broadcast_* field groups
 * are listed. This looks like removed and added lines shown together -- confirm which
 * header and which fields are live.
 */
Struct(WaveCtx)
AlignedStruct(WaveCtx, CachelineSize)
{
/* Number of lanes in the wave; WaveSyncEx compares its arrival count against this. */
i32 lanes_count;
/* Barrier */
void *barrier_broadcast_data;
Atomic32 barrier_blocked_count;
Atomic64Padded barrier_gen;
/* Sync barrier */
/* Lanes that have arrived at the current WaveSyncEx; reset to 0 by the last arrival. */
Atomic32Padded sync_count;
/* Generation that waiting lanes watch; changed by the last arrival to release them. */
Atomic64Padded sync_gen;
/* Broadcast barrier */
/* Pointer published by the broadcasting lane; receiving lanes copy from it. */
void *broadcast_data;
/* Incremented by the broadcasting lane after it has published broadcast_data. */
Atomic64Padded broadcast_gen;
/* Receiving lanes that have finished copying; reset to 0 by the last one. */
Atomic32Padded ack_count;
/* Incremented by the last receiving lane; the broadcasting lane waits for it to change. */
Atomic64Padded ack_gen;
};
/*
 * Per-lane state handed to the wave synchronization functions.
 *
 * NOTE(review): both a Struct(...) and an AlignedStruct(...) header are listed; this
 * looks like the removed and added header shown together -- confirm which one is live.
 */
Struct(WaveLaneCtx)
AlignedStruct(WaveLaneCtx, CachelineSize)
{
/* Index of this lane within its wave; compared against broadcast_lane_idx. */
i32 idx;
/* The wave context shared with the other lanes of the same wave. */
WaveCtx *wave;
/* Initialized to DefaultWaveLaneSpinCount by DispatchWave. */
u64 default_spin_count;
/* Lane-local copy of the broadcast generation, updated on every broadcast this lane
 * takes part in and compared against WaveCtx.broadcast_gen when waiting. */
i64 seen_broadcast_gen;
};
/* Function type of the `entry` argument taken by DispatchWave. Receives a lane context
 * and a user pointer -- presumably the `udata` given to DispatchWave; confirm at the call site. */
typedef void WaveLaneEntryFunc(WaveLaneCtx *lane, void *udata);

View File

@ -15,7 +15,6 @@ void W32_InitCurrentThread(String name)
Arena *perm = PermArena();
/* Fixme: Set thread name */
wchar_t *thread_name_wstr = WstrFromString(perm, name);
SetThreadDescription(GetCurrentThread(), thread_name_wstr);
@ -37,24 +36,14 @@ DWORD WINAPI W32_ThreadProc(LPVOID thread_args_vp)
void DispatchWave(String name, u32 num_lanes, WaveLaneEntryFunc *entry, void *udata)
{
/* FIXME: Impl */
Arena *perm = PermArena();
WaveCtx *wave_ctx = 0;
{
PushAlign(perm, CachelineSize);
wave_ctx = PushStruct(perm, WaveCtx);
PushAlign(perm, CachelineSize);
}
WaveCtx *wave_ctx = PushStruct(perm, WaveCtx);
wave_ctx->lanes_count = num_lanes;
for (u32 lane_idx = 0; lane_idx < num_lanes; ++lane_idx)
{
PushAlign(perm, CachelineSize);
WaveLaneCtx *lane_ctx = PushStruct(perm, WaveLaneCtx);
PushAlign(perm, CachelineSize);
lane_ctx->idx = lane_idx;
lane_ctx->wave = wave_ctx;
lane_ctx->default_spin_count = DefaultWaveLaneSpinCount;

View File

@ -1125,6 +1125,6 @@ void StartupLayers(void)
{
OS_Startup();
CpuTopologyInfo cpu_info = GetCpuTopologyInfo();
/* NOTE(review): meta_lanes_count is declared twice; this looks like the removed and the
 * added line shown together. The second form uses one lane fewer than the number of
 * logical cores -- presumably to leave a core free; confirm the intent. */
i32 meta_lanes_count = cpu_info.num_logical_cores;
i32 meta_lanes_count = cpu_info.num_logical_cores - 1;
/* The count is passed through MaxI32(..., 1), so the "Meta" wave is presumably always
 * dispatched with at least one lane (assuming MaxI32 returns the larger argument). */
DispatchWave(Lit("Meta"), MaxI32(meta_lanes_count, 1), BuildEntryPoint, 0);
}