Merge pull request #3979 from ReinUsesLisp/thread-group
shader/other: Implement thread comparisons (NV_shader_thread_group)
This commit is contained in:
commit
487dd05170
|
@ -2309,6 +2309,18 @@ private:
|
||||||
return {"gl_SubGroupInvocationARB", Type::Uint};
|
return {"gl_SubGroupInvocationARB", Type::Uint};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <const std::string_view& comparison>
|
||||||
|
Expression ThreadMask(Operation) {
|
||||||
|
if (device.HasWarpIntrinsics()) {
|
||||||
|
return {fmt::format("gl_Thread{}MaskNV", comparison), Type::Uint};
|
||||||
|
}
|
||||||
|
if (device.HasShaderBallot()) {
|
||||||
|
return {fmt::format("uint(gl_SubGroup{}MaskARB)", comparison), Type::Uint};
|
||||||
|
}
|
||||||
|
LOG_ERROR(Render_OpenGL, "Thread mask intrinsics are required by the shader");
|
||||||
|
return {"0U", Type::Uint};
|
||||||
|
}
|
||||||
|
|
||||||
Expression ShuffleIndexed(Operation operation) {
|
Expression ShuffleIndexed(Operation operation) {
|
||||||
std::string value = VisitOperand(operation, 0).AsFloat();
|
std::string value = VisitOperand(operation, 0).AsFloat();
|
||||||
|
|
||||||
|
@ -2337,6 +2349,12 @@ private:
|
||||||
static constexpr std::string_view NotEqual = "!=";
|
static constexpr std::string_view NotEqual = "!=";
|
||||||
static constexpr std::string_view GreaterEqual = ">=";
|
static constexpr std::string_view GreaterEqual = ">=";
|
||||||
|
|
||||||
|
static constexpr std::string_view Eq = "Eq";
|
||||||
|
static constexpr std::string_view Ge = "Ge";
|
||||||
|
static constexpr std::string_view Gt = "Gt";
|
||||||
|
static constexpr std::string_view Le = "Le";
|
||||||
|
static constexpr std::string_view Lt = "Lt";
|
||||||
|
|
||||||
static constexpr std::string_view Add = "Add";
|
static constexpr std::string_view Add = "Add";
|
||||||
static constexpr std::string_view Min = "Min";
|
static constexpr std::string_view Min = "Min";
|
||||||
static constexpr std::string_view Max = "Max";
|
static constexpr std::string_view Max = "Max";
|
||||||
|
@ -2554,6 +2572,11 @@ private:
|
||||||
&GLSLDecompiler::VoteEqual,
|
&GLSLDecompiler::VoteEqual,
|
||||||
|
|
||||||
&GLSLDecompiler::ThreadId,
|
&GLSLDecompiler::ThreadId,
|
||||||
|
&GLSLDecompiler::ThreadMask<Func::Eq>,
|
||||||
|
&GLSLDecompiler::ThreadMask<Func::Ge>,
|
||||||
|
&GLSLDecompiler::ThreadMask<Func::Gt>,
|
||||||
|
&GLSLDecompiler::ThreadMask<Func::Le>,
|
||||||
|
&GLSLDecompiler::ThreadMask<Func::Lt>,
|
||||||
&GLSLDecompiler::ShuffleIndexed,
|
&GLSLDecompiler::ShuffleIndexed,
|
||||||
|
|
||||||
&GLSLDecompiler::MemoryBarrierGL,
|
&GLSLDecompiler::MemoryBarrierGL,
|
||||||
|
|
|
@ -515,6 +515,16 @@ private:
|
||||||
// Declares the subgroup built-in inputs used by the thread operations:
// the local invocation id (stored in thread_id) and the five comparison
// masks (stored in thread_masks).
void DeclareCommon() {
|
void DeclareCommon() {
|
||||||
thread_id =
|
thread_id =
|
||||||
DeclareInputBuiltIn(spv::BuiltIn::SubgroupLocalInvocationId, t_in_uint, "thread_id");
|
DeclareInputBuiltIn(spv::BuiltIn::SubgroupLocalInvocationId, t_in_uint, "thread_id");
|
||||||
|
// Masks are stored in eq, ge, gt, le, lt order; ThreadMask<0..4> indexes
// thread_masks with the same positions.
thread_masks[0] =
|
||||||
|
DeclareInputBuiltIn(spv::BuiltIn::SubgroupEqMask, t_in_uint4, "thread_eq_mask");
|
||||||
|
thread_masks[1] =
|
||||||
|
DeclareInputBuiltIn(spv::BuiltIn::SubgroupGeMask, t_in_uint4, "thread_ge_mask");
|
||||||
|
thread_masks[2] =
|
||||||
|
DeclareInputBuiltIn(spv::BuiltIn::SubgroupGtMask, t_in_uint4, "thread_gt_mask");
|
||||||
|
thread_masks[3] =
|
||||||
|
DeclareInputBuiltIn(spv::BuiltIn::SubgroupLeMask, t_in_uint4, "thread_le_mask");
|
||||||
|
thread_masks[4] =
|
||||||
|
DeclareInputBuiltIn(spv::BuiltIn::SubgroupLtMask, t_in_uint4, "thread_lt_mask");
|
||||||
}
|
}
|
||||||
|
|
||||||
void DeclareVertex() {
|
void DeclareVertex() {
|
||||||
|
@ -2175,6 +2185,13 @@ private:
|
||||||
return {OpLoad(t_uint, thread_id), Type::Uint};
|
return {OpLoad(t_uint, thread_id), Type::Uint};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <std::size_t index>
|
||||||
|
Expression ThreadMask(Operation) {
|
||||||
|
// TODO(Rodrigo): Handle devices with different warp sizes
|
||||||
|
const Id mask = thread_masks[index];
|
||||||
|
return {OpLoad(t_uint, AccessElement(t_in_uint, mask, 0)), Type::Uint};
|
||||||
|
}
|
||||||
|
|
||||||
Expression ShuffleIndexed(Operation operation) {
|
Expression ShuffleIndexed(Operation operation) {
|
||||||
const Id value = AsFloat(Visit(operation[0]));
|
const Id value = AsFloat(Visit(operation[0]));
|
||||||
const Id index = AsUint(Visit(operation[1]));
|
const Id index = AsUint(Visit(operation[1]));
|
||||||
|
@ -2639,6 +2656,11 @@ private:
|
||||||
&SPIRVDecompiler::Vote<&Module::OpSubgroupAllEqualKHR>,
|
&SPIRVDecompiler::Vote<&Module::OpSubgroupAllEqualKHR>,
|
||||||
|
|
||||||
&SPIRVDecompiler::ThreadId,
|
&SPIRVDecompiler::ThreadId,
|
||||||
|
&SPIRVDecompiler::ThreadMask<0>, // Eq
|
||||||
|
&SPIRVDecompiler::ThreadMask<1>, // Ge
|
||||||
|
&SPIRVDecompiler::ThreadMask<2>, // Gt
|
||||||
|
&SPIRVDecompiler::ThreadMask<3>, // Le
|
||||||
|
&SPIRVDecompiler::ThreadMask<4>, // Lt
|
||||||
&SPIRVDecompiler::ShuffleIndexed,
|
&SPIRVDecompiler::ShuffleIndexed,
|
||||||
|
|
||||||
&SPIRVDecompiler::MemoryBarrierGL,
|
&SPIRVDecompiler::MemoryBarrierGL,
|
||||||
|
@ -2763,6 +2785,7 @@ private:
|
||||||
Id workgroup_id{};
|
Id workgroup_id{};
|
||||||
Id local_invocation_id{};
|
Id local_invocation_id{};
|
||||||
Id thread_id{};
|
Id thread_id{};
|
||||||
|
std::array<Id, 5> thread_masks{}; // eq, ge, gt, le, lt
|
||||||
|
|
||||||
VertexIndices in_indices;
|
VertexIndices in_indices;
|
||||||
VertexIndices out_indices;
|
VertexIndices out_indices;
|
||||||
|
|
|
@ -109,6 +109,27 @@ u32 ShaderIR::DecodeOther(NodeBlock& bb, u32 pc) {
|
||||||
return Operation(OperationCode::WorkGroupIdY);
|
return Operation(OperationCode::WorkGroupIdY);
|
||||||
case SystemVariable::CtaIdZ:
|
case SystemVariable::CtaIdZ:
|
||||||
return Operation(OperationCode::WorkGroupIdZ);
|
return Operation(OperationCode::WorkGroupIdZ);
|
||||||
|
case SystemVariable::EqMask:
|
||||||
|
case SystemVariable::LtMask:
|
||||||
|
case SystemVariable::LeMask:
|
||||||
|
case SystemVariable::GtMask:
|
||||||
|
case SystemVariable::GeMask:
|
||||||
|
uses_warps = true;
|
||||||
|
switch (instr.sys20) {
|
||||||
|
case SystemVariable::EqMask:
|
||||||
|
return Operation(OperationCode::ThreadEqMask);
|
||||||
|
case SystemVariable::LtMask:
|
||||||
|
return Operation(OperationCode::ThreadLtMask);
|
||||||
|
case SystemVariable::LeMask:
|
||||||
|
return Operation(OperationCode::ThreadLeMask);
|
||||||
|
case SystemVariable::GtMask:
|
||||||
|
return Operation(OperationCode::ThreadGtMask);
|
||||||
|
case SystemVariable::GeMask:
|
||||||
|
return Operation(OperationCode::ThreadGeMask);
|
||||||
|
default:
|
||||||
|
UNREACHABLE();
|
||||||
|
return Immediate(0u);
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
UNIMPLEMENTED_MSG("Unhandled system move: {}",
|
UNIMPLEMENTED_MSG("Unhandled system move: {}",
|
||||||
static_cast<u32>(instr.sys20.Value()));
|
static_cast<u32>(instr.sys20.Value()));
|
||||||
|
|
|
@ -226,6 +226,11 @@ enum class OperationCode {
|
||||||
VoteEqual, /// (bool) -> bool
|
VoteEqual, /// (bool) -> bool
|
||||||
|
|
||||||
ThreadId, /// () -> uint
|
ThreadId, /// () -> uint
|
||||||
|
ThreadEqMask, /// () -> uint
|
||||||
|
ThreadGeMask, /// () -> uint
|
||||||
|
ThreadGtMask, /// () -> uint
|
||||||
|
ThreadLeMask, /// () -> uint
|
||||||
|
ThreadLtMask, /// () -> uint
|
||||||
ShuffleIndexed, /// (uint value, uint index) -> uint
|
ShuffleIndexed, /// (uint value, uint index) -> uint
|
||||||
|
|
||||||
MemoryBarrierGL, /// () -> void
|
MemoryBarrierGL, /// () -> void
|
||||||
|
|
Reference in New Issue