shader: Add XMAD multiplication folding optimization
This commit is contained in:
parent
4b438f94cf
commit
58914796c0
|
@ -9,6 +9,7 @@
|
||||||
#include "common/bit_cast.h"
|
#include "common/bit_cast.h"
|
||||||
#include "common/bit_util.h"
|
#include "common/bit_util.h"
|
||||||
#include "shader_recompiler/exception.h"
|
#include "shader_recompiler/exception.h"
|
||||||
|
#include "shader_recompiler/frontend/ir/ir_emitter.h"
|
||||||
#include "shader_recompiler/frontend/ir/microinstruction.h"
|
#include "shader_recompiler/frontend/ir/microinstruction.h"
|
||||||
#include "shader_recompiler/ir_opt/passes.h"
|
#include "shader_recompiler/ir_opt/passes.h"
|
||||||
|
|
||||||
|
@ -99,8 +100,71 @@ void FoldGetPred(IR::Inst& inst) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Replaces the pattern generated by two XMAD multiplications
|
||||||
|
bool FoldXmadMultiply(IR::Block& block, IR::Inst& inst) {
|
||||||
|
/*
|
||||||
|
* We are looking for this pattern:
|
||||||
|
* %rhs_bfe = BitFieldUExtract %factor_a, #0, #16 (uses: 1)
|
||||||
|
* %rhs_mul = IMul32 %rhs_bfe, %factor_b (uses: 1)
|
||||||
|
* %lhs_bfe = BitFieldUExtract %factor_a, #16, #16 (uses: 1)
|
||||||
|
* %rhs_mul = IMul32 %lhs_bfe, %factor_b (uses: 1)
|
||||||
|
* %lhs_shl = ShiftLeftLogical32 %rhs_mul, #16 (uses: 1)
|
||||||
|
* %result = IAdd32 %lhs_shl, %rhs_mul (uses: 10)
|
||||||
|
*
|
||||||
|
* And replacing it with
|
||||||
|
* %result = IMul32 %factor_a, %factor_b
|
||||||
|
*
|
||||||
|
* This optimization has been proven safe by LLVM and MSVC.
|
||||||
|
*/
|
||||||
|
const IR::Value lhs_arg{inst.Arg(0)};
|
||||||
|
const IR::Value rhs_arg{inst.Arg(1)};
|
||||||
|
if (lhs_arg.IsImmediate() || rhs_arg.IsImmediate()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
IR::Inst* const lhs_shl{lhs_arg.InstRecursive()};
|
||||||
|
if (lhs_shl->Opcode() != IR::Opcode::ShiftLeftLogical32 || lhs_shl->Arg(1) != IR::Value{16U}) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (lhs_shl->Arg(0).IsImmediate()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
IR::Inst* const lhs_mul{lhs_shl->Arg(0).InstRecursive()};
|
||||||
|
IR::Inst* const rhs_mul{rhs_arg.InstRecursive()};
|
||||||
|
if (lhs_mul->Opcode() != IR::Opcode::IMul32 || rhs_mul->Opcode() != IR::Opcode::IMul32) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (lhs_mul->Arg(1).Resolve() != rhs_mul->Arg(1).Resolve()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
const IR::U32 factor_b{lhs_mul->Arg(1)};
|
||||||
|
if (lhs_mul->Arg(0).IsImmediate() || rhs_mul->Arg(0).IsImmediate()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
IR::Inst* const lhs_bfe{lhs_mul->Arg(0).InstRecursive()};
|
||||||
|
IR::Inst* const rhs_bfe{rhs_mul->Arg(0).InstRecursive()};
|
||||||
|
if (lhs_bfe->Opcode() != IR::Opcode::BitFieldUExtract) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (rhs_bfe->Opcode() != IR::Opcode::BitFieldUExtract) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (lhs_bfe->Arg(1) != IR::Value{16U} || lhs_bfe->Arg(2) != IR::Value{16U}) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (rhs_bfe->Arg(1) != IR::Value{0U} || rhs_bfe->Arg(2) != IR::Value{16U}) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (lhs_bfe->Arg(0).Resolve() != rhs_bfe->Arg(0).Resolve()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
const IR::U32 factor_a{lhs_bfe->Arg(0)};
|
||||||
|
IR::IREmitter ir{block, IR::Block::InstructionList::s_iterator_to(inst)};
|
||||||
|
inst.ReplaceUsesWith(ir.IMul(factor_a, factor_b));
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void FoldAdd(IR::Inst& inst) {
|
void FoldAdd(IR::Block& block, IR::Inst& inst) {
|
||||||
if (inst.HasAssociatedPseudoOperation()) {
|
if (inst.HasAssociatedPseudoOperation()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -110,6 +174,12 @@ void FoldAdd(IR::Inst& inst) {
|
||||||
const IR::Value rhs{inst.Arg(1)};
|
const IR::Value rhs{inst.Arg(1)};
|
||||||
if (rhs.IsImmediate() && Arg<T>(rhs) == 0) {
|
if (rhs.IsImmediate() && Arg<T>(rhs) == 0) {
|
||||||
inst.ReplaceUsesWith(inst.Arg(0));
|
inst.ReplaceUsesWith(inst.Arg(0));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if constexpr (std::is_same_v<T, u32>) {
|
||||||
|
if (FoldXmadMultiply(block, inst)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -244,14 +314,14 @@ void FoldBranchConditional(IR::Inst& inst) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void ConstantPropagation(IR::Inst& inst) {
|
void ConstantPropagation(IR::Block& block, IR::Inst& inst) {
|
||||||
switch (inst.Opcode()) {
|
switch (inst.Opcode()) {
|
||||||
case IR::Opcode::GetRegister:
|
case IR::Opcode::GetRegister:
|
||||||
return FoldGetRegister(inst);
|
return FoldGetRegister(inst);
|
||||||
case IR::Opcode::GetPred:
|
case IR::Opcode::GetPred:
|
||||||
return FoldGetPred(inst);
|
return FoldGetPred(inst);
|
||||||
case IR::Opcode::IAdd32:
|
case IR::Opcode::IAdd32:
|
||||||
return FoldAdd<u32>(inst);
|
return FoldAdd<u32>(block, inst);
|
||||||
case IR::Opcode::ISub32:
|
case IR::Opcode::ISub32:
|
||||||
return FoldISub32(inst);
|
return FoldISub32(inst);
|
||||||
case IR::Opcode::BitCastF32U32:
|
case IR::Opcode::BitCastF32U32:
|
||||||
|
@ -259,7 +329,7 @@ void ConstantPropagation(IR::Inst& inst) {
|
||||||
case IR::Opcode::BitCastU32F32:
|
case IR::Opcode::BitCastU32F32:
|
||||||
return FoldBitCast<u32, f32>(inst, IR::Opcode::BitCastF32U32);
|
return FoldBitCast<u32, f32>(inst, IR::Opcode::BitCastF32U32);
|
||||||
case IR::Opcode::IAdd64:
|
case IR::Opcode::IAdd64:
|
||||||
return FoldAdd<u64>(inst);
|
return FoldAdd<u64>(block, inst);
|
||||||
case IR::Opcode::Select32:
|
case IR::Opcode::Select32:
|
||||||
return FoldSelect<u32>(inst);
|
return FoldSelect<u32>(inst);
|
||||||
case IR::Opcode::LogicalAnd:
|
case IR::Opcode::LogicalAnd:
|
||||||
|
@ -292,7 +362,9 @@ void ConstantPropagation(IR::Inst& inst) {
|
||||||
} // Anonymous namespace
|
} // Anonymous namespace
|
||||||
|
|
||||||
void ConstantPropagationPass(IR::Block& block) {
|
void ConstantPropagationPass(IR::Block& block) {
|
||||||
std::ranges::for_each(block, ConstantPropagation);
|
for (IR::Inst& inst : block) {
|
||||||
|
ConstantPropagation(block, inst);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace Shader::Optimization
|
} // namespace Shader::Optimization
|
||||||
|
|
Reference in New Issue