142 lines
3.5 KiB
C
142 lines
3.5 KiB
C
////////////////////////////////////////////////////////////
//~ Wave sync ops

// Barrier over every lane of `lane->wave`: a call does not return until all
// `lanes_count` lanes have entered the same round. A wave with at most one
// lane returns immediately.
//
// lane       - calling lane; only `lane->wave` is read here.
// spin_count - number of `_mm_pause()` iterations a waiting lane performs
//              before it starts calling FutexYieldNeq on the generation.
void WaveSyncEx(WaveLaneCtx *lane, u64 spin_count)
{
  WaveCtx *wave = lane->wave;
  i32 lanes_count = wave->lanes_count;
  if (lanes_count <= 1)
  {
    // Nothing to synchronize with
    return;
  }

  // The generation is sampled before this lane registers its arrival. The
  // generation only advances once the arrival count reaches `lanes_count`,
  // which cannot happen without this lane's own increment below.
  i64 entry_gen = Atomic64Fetch(&wave->sync_gen.v);
  i32 arrived_count = Atomic32FetchAdd(&wave->sync_count.v, 1) + 1;

  if (arrived_count == lanes_count)
  {
    // Last lane in: reset the counter for the next round first, then publish
    // the new generation and wake lanes waiting on it.
    Atomic32Set(&wave->sync_count.v, 0);
    Atomic64Set(&wave->sync_gen.v, entry_gen + 1);
    FutexWakeNeq(&wave->sync_gen.v);
    return;
  }

  // Not the last lane: wait until the generation moves past the sampled value
  u64 spins_left = spin_count;
  while (Atomic64Fetch(&wave->sync_gen.v) == entry_gen)
  {
    if (spins_left == 0)
    {
      // Spin budget exhausted.
      // NOTE(review): FutexYieldNeq presumably blocks while the value still
      // equals `entry_gen`, with I64Max acting as "no timeout" - confirm
      // against the futex helper's definition.
      FutexYieldNeq(&wave->sync_gen.v, &entry_gen, sizeof(entry_gen), I64Max);
    }
    else
    {
      --spins_left;
      _mm_pause();
    }
  }
}
|
|
|
|
// Broadcasts `broadcast_size` bytes from one lane to every other lane in the
// wave. A wave with at most one lane returns immediately without touching
// anything.
//
// lane               - calling lane.
// broadcast_lane_idx - index of the lane whose data is the source.
// broadcast_ptr      - on the broadcasting lane: the source bytes.
//                      on every other lane: the destination buffer.
// broadcast_size     - byte count handed to CopyBytes on receiving lanes.
// spin_count         - `_mm_pause()` iterations before a waiter starts
//                      calling FutexYieldNeq.
//
// The broadcasting lane does not return until the ack generation changes,
// which happens once `lanes_count - 1` receivers have copied and acked.
// NOTE(review): FutexYieldNeq presumably blocks while the watched value still
// equals the compared one, with I64Max acting as "no timeout" - confirm
// against the futex helper's definition.
void WaveSyncBroadcastEx_(WaveLaneCtx *lane, u32 broadcast_lane_idx, void *broadcast_ptr, u64 broadcast_size, u64 spin_count)
{
  WaveCtx *wave = lane->wave;
  i32 lanes_count = wave->lanes_count;
  if (lanes_count <= 1)
  {
    return;
  }

  u32 self_idx = lane->idx;
  if (self_idx != broadcast_lane_idx)
  {
    //- Receiver: wait for the broadcast generation to move past the one this
    //  lane last saw. The lane's own counter is advanced up front.
    i64 awaited_gen = lane->seen_broadcast_gen++;
    u64 spins_left = spin_count;
    while (Atomic64Fetch(&wave->broadcast_gen.v) == awaited_gen)
    {
      if (spins_left == 0)
      {
        FutexYieldNeq(&wave->broadcast_gen.v, &awaited_gen, sizeof(awaited_gen), I64Max);
      }
      else
      {
        --spins_left;
        _mm_pause();
      }
    }

    //- Copy the published data into this lane's buffer
    CopyBytes(broadcast_ptr, wave->broadcast_data, broadcast_size);

    //- Ack. The last of the `lanes_count - 1` receivers resets the ack
    //  counter, bumps the ack generation and wakes the broadcaster.
    i32 acks_so_far = Atomic32FetchAdd(&wave->ack_count.v, 1) + 1;
    if (acks_so_far == lanes_count - 1)
    {
      Atomic32Set(&wave->ack_count.v, 0);
      Atomic64FetchAdd(&wave->ack_gen.v, 1);
      FutexWakeNeq(&wave->ack_gen.v);
    }
    return;
  }

  //- Broadcaster: publish the pointer, sample the ack generation, then
  //  advance the broadcast generation and wake the receivers.
  wave->broadcast_data = broadcast_ptr;
  i64 prior_ack_gen = Atomic64Fetch(&wave->ack_gen.v);
  lane->seen_broadcast_gen = Atomic64FetchAdd(&wave->broadcast_gen.v, 1) + 1;
  FutexWakeNeq(&wave->broadcast_gen.v);

  //- Wait until the ack generation moves past the sampled value
  u64 spins_left = spin_count;
  while (Atomic64Fetch(&wave->ack_gen.v) == prior_ack_gen)
  {
    if (spins_left == 0)
    {
      FutexYieldNeq(&wave->ack_gen.v, &prior_ack_gen, sizeof(prior_ack_gen), I64Max);
    }
    else
    {
      --spins_left;
      _mm_pause();
    }
  }
}
|
|
|
|
// Stores `n` as the lane's `default_spin_count`.
// NOTE(review): presumably this is the spin budget used by sync wrappers that
// do not take an explicit `spin_count` - those callers are not visible here.
void SetWaveLaneDefaultSpin(WaveLaneCtx *lane, u64 n)
{
  lane->default_spin_count = n;
}
|
|
|
|
////////////////////////////////////////////////////////////
//~ Lane helpers

// Maps a task index onto a lane index by taking `task_idx` modulo the number
// of lanes in the calling lane's wave.
i32 WaveLaneIdxFromTaskIdx(WaveLaneCtx *lane, u64 task_idx)
{
  return task_idx % lane->wave->lanes_count;
}
|
|
|
|
// Computes the slice of `tasks_count` task indices assigned to this lane.
//
// Every lane gets `tasks_count / lanes_count` tasks. The remainder is handed
// out one task each to the lanes with the lowest indices. The slices of
// consecutive lanes are adjacent: lane i's end equals lane i+1's start.
//
// Returns RNGU64(start, end), where `end - start` is the number of tasks
// assigned to this lane.
RngU64 WaveIdxRangeFromCount(WaveLaneCtx *lane, u64 tasks_count)
{
  u64 lane_idx = lane->idx;
  u64 lanes_count = lane->wave->lanes_count;

  u64 base_len = tasks_count / lanes_count;
  u64 leftover = tasks_count % lanes_count;

  // 1 when this lane is among the first `leftover` lanes, otherwise 0
  u64 takes_extra = lane_idx < leftover;

  // Number of extra tasks absorbed by the lanes ahead of this one
  u64 extras_before = takes_extra ? lane_idx : leftover;

  u64 first = lane_idx * base_len + extras_before;
  u64 past_last = first + base_len + takes_extra;

  return RNGU64(first, past_last);
}
|