power_play/src/base/base_wave.c

142 lines
3.5 KiB
C

////////////////////////////////////////////////////////////
//~ Wave sync ops
// Barrier across all lanes of a wave: the calling lane does not return until
// every one of the wave's `lanes_count` lanes has entered this function.
//
// lane       - the calling lane; its `wave` supplies the shared sync state.
// spin_count - how many `_mm_pause()` iterations to busy-wait before falling
//              back to `FutexYieldNeq` on the generation counter.
//
// Does nothing when the wave has one lane or fewer.
void WaveSyncEx(WaveLaneCtx *lane, u64 spin_count)
{
WaveCtx *wave = lane->wave;
i32 lanes_count = wave->lanes_count;
if (lanes_count > 1)
{
// Snapshot the generation BEFORE registering our arrival. The generation is
// only bumped by the lane that observes the final arrival, so (assuming every
// lane takes part in each sync) this snapshot is always the pre-release value.
i64 sync_gen = Atomic64Fetch(&wave->sync_gen.v);
// The `+ 1` turns the fetch-add result into the arrival count including this
// lane (i.e. Atomic32FetchAdd is used here as returning the previous value).
i32 blocked_count = Atomic32FetchAdd(&wave->sync_count.v, 1) + 1;
if (blocked_count == lanes_count)
{
// Last lane to arrive: reset the arrival counter for the next sync first,
// then publish the new generation and wake the waiting lanes.
Atomic32Set(&wave->sync_count.v, 0);
Atomic64Set(&wave->sync_gen.v, sync_gen + 1);
FutexWakeNeq(&wave->sync_gen.v);
}
else
{
// Not the last lane: wait until the generation moves past our snapshot.
u64 remaining_spins = spin_count;
while (Atomic64Fetch(&wave->sync_gen.v) == sync_gen)
{
if (remaining_spins > 0)
{
// Spin phase
--remaining_spins;
_mm_pause();
}
else
{
// Spin budget exhausted. NOTE(review): presumably this blocks until the
// value at the address differs from `sync_gen`, with I64Max acting as
// "no timeout" - confirm against the FutexYieldNeq definition. The
// enclosing loop re-checks the generation after every return either way.
FutexYieldNeq(&wave->sync_gen.v, &sync_gen, sizeof(sync_gen), I64Max);
}
}
}
}
}
// Broadcasts a block of bytes from one lane to every other lane in the wave,
// and acts as a rendezvous between them.
//
// lane               - the calling lane.
// broadcast_lane_idx - index of the lane whose data is being broadcast.
// broadcast_ptr      - on the broadcasting lane: the data to publish.
//                      on every other lane: the buffer passed to CopyBytes to
//                      receive the published data.
// broadcast_size     - byte count passed to CopyBytes on receiving lanes.
// spin_count         - how many `_mm_pause()` iterations to busy-wait before
//                      falling back to `FutexYieldNeq`.
//
// The broadcasting lane returns only after `ack_gen` changes, which happens
// once `lanes_count - 1` lanes have acknowledged. Receiving lanes return right
// after acknowledging. Does nothing when the wave has one lane or fewer.
void WaveSyncBroadcastEx_(WaveLaneCtx *lane, u32 broadcast_lane_idx, void *broadcast_ptr, u64 broadcast_size, u64 spin_count)
{
WaveCtx *wave = lane->wave;
i32 lanes_count = wave->lanes_count;
if (lanes_count > 1)
{
u32 lane_idx = lane->idx;
if (lane_idx == broadcast_lane_idx)
{
// Broadcast: publish the source pointer first, then bump the broadcast
// generation so that receivers only look at broadcast_data after it is set.
wave->broadcast_data = broadcast_ptr;
// Snapshot the ack generation BEFORE announcing the broadcast, so acks that
// arrive immediately after the announcement cannot be missed.
i64 ack_gen = Atomic64Fetch(&wave->ack_gen.v);
// Record the new broadcast generation as already seen by this lane
// (fetch-add result + 1, i.e. assuming it returns the previous value).
lane->seen_broadcast_gen = Atomic64FetchAdd(&wave->broadcast_gen.v, 1) + 1;
FutexWakeNeq(&wave->broadcast_gen.v);
// Wait for ack: spin, then yield, until the ack generation changes
{
u64 remaining_spins = spin_count;
while (Atomic64Fetch(&wave->ack_gen.v) == ack_gen)
{
if (remaining_spins > 0)
{
--remaining_spins;
_mm_pause();
}
else
{
// NOTE(review): presumably blocks while the value equals `ack_gen`,
// with I64Max as "no timeout" - confirm against FutexYieldNeq.
FutexYieldNeq(&wave->ack_gen.v, &ack_gen, sizeof(ack_gen), I64Max);
}
}
}
}
else
{
// Wait for broadcast: take the generation this lane last saw and advance the
// lane-local counter (post-increment), then wait until the shared broadcast
// generation differs from the value taken.
i64 seen_broadcast_gen = lane->seen_broadcast_gen++;
{
u64 remaining_spins = spin_count;
while (Atomic64Fetch(&wave->broadcast_gen.v) == seen_broadcast_gen)
{
if (remaining_spins > 0)
{
--remaining_spins;
_mm_pause();
}
else
{
FutexYieldNeq(&wave->broadcast_gen.v, &seen_broadcast_gen, sizeof(seen_broadcast_gen), I64Max);
}
}
}
// Copy broadcast data out of the broadcaster's buffer into this lane's buffer
// (presumably CopyBytes takes destination first, then source - verify).
CopyBytes(broadcast_ptr, wave->broadcast_data, broadcast_size);
// Ack: the last of the `lanes_count - 1` receiving lanes resets the ack
// counter, bumps the ack generation, and wakes the broadcasting lane.
i32 ack_count = Atomic32FetchAdd(&wave->ack_count.v, 1) + 1;
if (ack_count == lanes_count - 1)
{
Atomic32Set(&wave->ack_count.v, 0);
Atomic64FetchAdd(&wave->ack_gen.v, 1);
FutexWakeNeq(&wave->ack_gen.v);
}
}
}
}
// Stores `n` as the lane's `default_spin_count`.
// NOTE(review): presumably this is the spin count used by sync wrappers that
// do not take an explicit `spin_count` - those wrappers are not visible here.
void SetWaveLaneDefaultSpin(WaveLaneCtx *lane, u64 n)
{
lane->default_spin_count = n;
}
////////////////////////////////////////////////////////////
//~ Lane helpers
// Maps a task index onto a lane index by taking it modulo the wave's lane
// count, so consecutive task indices cycle through the lanes in order.
//
// lane     - any lane of the wave; only its `wave->lanes_count` is read.
// task_idx - index of the task to map.
//
// Returns `task_idx % lanes_count`, narrowed to i32.
i32 WaveLaneIdxFromTaskIdx(WaveLaneCtx *lane, u64 task_idx)
{
  // Do the modulo in u64, then narrow; these are the same conversions the
  // compiler applies implicitly, spelled out.
  u64 lanes_count = (u64)lane->wave->lanes_count;
  return (i32)(task_idx % lanes_count);
}
// Computes the contiguous slice of `tasks_count` task indices that belongs to
// the calling lane when the tasks are divided as evenly as possible.
//
// Every lane gets `tasks_count / lanes_count` tasks; the first
// `tasks_count % lanes_count` lanes each get one extra. The slices of
// consecutive lanes are adjacent: one lane's `end` is the next lane's `start`.
//
// lane        - the calling lane; reads `lane->idx` and `lane->wave->lanes_count`.
// tasks_count - total number of tasks to divide.
//
// Returns RNGU64(start, end), where `end - start` is this lane's task count.
RngU64 WaveIdxRangeFromCount(WaveLaneCtx *lane, u64 tasks_count)
{
  u64 lane_idx = lane->idx;
  u64 lanes_count = lane->wave->lanes_count;

  u64 base_share = tasks_count / lanes_count;
  u64 leftover = tasks_count % lanes_count;

  // 1 when this lane is among the first `leftover` lanes that take an extra task.
  u64 takes_extra = (lane_idx < leftover) ? 1 : 0;
  // Number of extra tasks handed to lanes before this one: min(lane_idx, leftover).
  u64 extras_before = takes_extra ? lane_idx : leftover;

  u64 start = lane_idx * base_share + extras_before;
  u64 end = start + base_share + takes_extra;
  return RNGU64(start, end);
}