1
0
Fork 0

shader_jit: Add optimizations up to `x86-64-v4` (#6668)

This commit is contained in:
Wunk 2023-07-11 09:21:37 -07:00 committed by GitHub
parent 6da4853360
commit a94af8ea62
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 157 additions and 63 deletions

View File

@ -338,15 +338,39 @@ void JitShader::Compile_SanitizedMul(Xmm src1, Xmm src2, Xmm scratch) {
// where neither source was, this NaN was generated by a 0 * inf multiplication, and so the // where neither source was, this NaN was generated by a 0 * inf multiplication, and so the
// result should be transformed to 0 to match PICA fp rules. // result should be transformed to 0 to match PICA fp rules.
if (host_caps.has(Cpu::tAVX512F | Cpu::tAVX512VL | Cpu::tAVX512DQ)) {
vmulps(scratch, src1, src2);
// Mask of any NaN values found in the result
const Xbyak::Opmask zero_mask = k1;
vcmpunordps(zero_mask, scratch, scratch);
// Mask of any non-NaN inputs producing NaN results
vcmpordps(zero_mask | zero_mask, src1, src2);
knotb(zero_mask, zero_mask);
vmovaps(src1 | zero_mask | T_z, scratch);
return;
}
// Set scratch to mask of (src1 != NaN and src2 != NaN) // Set scratch to mask of (src1 != NaN and src2 != NaN)
movaps(scratch, src1); if (host_caps.has(Cpu::tAVX)) {
cmpordps(scratch, src2); vcmpordps(scratch, src1, src2);
} else {
movaps(scratch, src1);
cmpordps(scratch, src2);
}
mulps(src1, src2); mulps(src1, src2);
// Set src2 to mask of (result == NaN) // Set src2 to mask of (result == NaN)
movaps(src2, src1); if (host_caps.has(Cpu::tAVX)) {
cmpunordps(src2, src2); vcmpunordps(src2, src2, src1);
} else {
movaps(src2, src1);
cmpunordps(src2, src2);
}
// Clear components where scratch != src2 (i.e. if result is NaN where neither source was NaN) // Clear components where scratch != src2 (i.e. if result is NaN where neither source was NaN)
xorps(scratch, src2); xorps(scratch, src2);
@ -406,13 +430,20 @@ void JitShader::Compile_DP3(Instruction instr) {
Compile_SanitizedMul(SRC1, SRC2, SCRATCH); Compile_SanitizedMul(SRC1, SRC2, SCRATCH);
movaps(SRC2, SRC1); if (host_caps.has(Cpu::tAVX)) {
shufps(SRC2, SRC2, _MM_SHUFFLE(1, 1, 1, 1)); vshufps(SRC3, SRC1, SRC1, _MM_SHUFFLE(2, 2, 2, 2));
vshufps(SRC2, SRC1, SRC1, _MM_SHUFFLE(1, 1, 1, 1));
vshufps(SRC1, SRC1, SRC1, _MM_SHUFFLE(0, 0, 0, 0));
} else {
movaps(SRC2, SRC1);
shufps(SRC2, SRC2, _MM_SHUFFLE(1, 1, 1, 1));
movaps(SRC3, SRC1); movaps(SRC3, SRC1);
shufps(SRC3, SRC3, _MM_SHUFFLE(2, 2, 2, 2)); shufps(SRC3, SRC3, _MM_SHUFFLE(2, 2, 2, 2));
shufps(SRC1, SRC1, _MM_SHUFFLE(0, 0, 0, 0));
}
shufps(SRC1, SRC1, _MM_SHUFFLE(0, 0, 0, 0));
addps(SRC1, SRC2); addps(SRC1, SRC2);
addps(SRC1, SRC3); addps(SRC1, SRC3);
@ -589,9 +620,15 @@ void JitShader::Compile_MOV(Instruction instr) {
void JitShader::Compile_RCP(Instruction instr) { void JitShader::Compile_RCP(Instruction instr) {
Compile_SwizzleSrc(instr, 1, instr.common.src1, SRC1); Compile_SwizzleSrc(instr, 1, instr.common.src1, SRC1);
// TODO(bunnei): RCPSS is a pretty rough approximation, this might cause problems if Pica if (host_caps.has(Cpu::tAVX512F | Cpu::tAVX512VL)) {
// performs this operation more accurately. This should be checked on hardware. // Accurate to 14 bits of precisions rather than 12 bits of rcpss
rcpss(SRC1, SRC1); vrcp14ss(SRC1, SRC1, SRC1);
} else {
// TODO(bunnei): RCPSS is a pretty rough approximation, this might cause problems if Pica
// performs this operation more accurately. This should be checked on hardware.
rcpss(SRC1, SRC1);
}
shufps(SRC1, SRC1, _MM_SHUFFLE(0, 0, 0, 0)); // XYWZ -> XXXX shufps(SRC1, SRC1, _MM_SHUFFLE(0, 0, 0, 0)); // XYWZ -> XXXX
Compile_DestEnable(instr, SRC1); Compile_DestEnable(instr, SRC1);
@ -600,9 +637,15 @@ void JitShader::Compile_RCP(Instruction instr) {
void JitShader::Compile_RSQ(Instruction instr) { void JitShader::Compile_RSQ(Instruction instr) {
Compile_SwizzleSrc(instr, 1, instr.common.src1, SRC1); Compile_SwizzleSrc(instr, 1, instr.common.src1, SRC1);
// TODO(bunnei): RSQRTSS is a pretty rough approximation, this might cause problems if Pica if (host_caps.has(Cpu::tAVX512F | Cpu::tAVX512VL)) {
// performs this operation more accurately. This should be checked on hardware. // Accurate to 14 bits of precisions rather than 12 bits of rsqrtss
rsqrtss(SRC1, SRC1); vrsqrt14ss(SRC1, SRC1, SRC1);
} else {
// TODO(bunnei): RSQRTSS is a pretty rough approximation, this might cause problems if Pica
// performs this operation more accurately. This should be checked on hardware.
rsqrtss(SRC1, SRC1);
}
shufps(SRC1, SRC1, _MM_SHUFFLE(0, 0, 0, 0)); // XYWZ -> XXXX shufps(SRC1, SRC1, _MM_SHUFFLE(0, 0, 0, 0)); // XYWZ -> XXXX
Compile_DestEnable(instr, SRC1); Compile_DestEnable(instr, SRC1);
@ -1050,32 +1093,47 @@ Xbyak::Label JitShader::CompilePrelude_Log2() {
jp(input_is_nan); jp(input_is_nan);
jae(input_out_of_range); jae(input_out_of_range);
// Split input // Split input: SRC1=MANT[1,2) SCRATCH2=Exponent
movd(eax, SRC1); if (host_caps.has(Cpu::tAVX512F | Cpu::tAVX512VL)) {
mov(edx, eax); vgetexpss(SCRATCH2, SRC1, SRC1);
and_(eax, 0x7f800000); vgetmantss(SRC1, SRC1, SRC1, 0x0'0);
and_(edx, 0x007fffff); } else {
movss(SCRATCH, xword[rip + c0]); // Preload c0. movd(eax, SRC1);
or_(edx, 0x3f800000); mov(edx, eax);
movd(SRC1, edx); and_(eax, 0x7f800000);
// SRC1 now contains the mantissa of the input. and_(edx, 0x007fffff);
mulss(SCRATCH, SRC1); or_(edx, 0x3f800000);
shr(eax, 23); movd(SRC1, edx);
sub(eax, 0x7f); // SRC1 now contains the mantissa of the input.
cvtsi2ss(SCRATCH2, eax); shr(eax, 23);
// SCRATCH2 now contains the exponent of the input. sub(eax, 0x7f);
cvtsi2ss(SCRATCH2, eax);
// SCRATCH2 now contains the exponent of the input.
}
movss(SCRATCH, xword[rip + c0]);
// Complete computation of polynomial // Complete computation of polynomial
addss(SCRATCH, xword[rip + c1]); if (host_caps.has(Cpu::tFMA)) {
mulss(SCRATCH, SRC1); vfmadd213ss(SCRATCH, SRC1, xword[rip + c1]);
addss(SCRATCH, xword[rip + c2]); vfmadd213ss(SCRATCH, SRC1, xword[rip + c2]);
mulss(SCRATCH, SRC1); vfmadd213ss(SCRATCH, SRC1, xword[rip + c3]);
addss(SCRATCH, xword[rip + c3]); vfmadd213ss(SCRATCH, SRC1, xword[rip + c4]);
mulss(SCRATCH, SRC1); subss(SRC1, ONE);
subss(SRC1, ONE); vfmadd231ss(SCRATCH2, SCRATCH, SRC1);
addss(SCRATCH, xword[rip + c4]); } else {
mulss(SCRATCH, SRC1); mulss(SCRATCH, SRC1);
addss(SCRATCH2, SCRATCH); addss(SCRATCH, xword[rip + c1]);
mulss(SCRATCH, SRC1);
addss(SCRATCH, xword[rip + c2]);
mulss(SCRATCH, SRC1);
addss(SCRATCH, xword[rip + c3]);
mulss(SCRATCH, SRC1);
subss(SRC1, ONE);
addss(SCRATCH, xword[rip + c4]);
mulss(SCRATCH, SRC1);
addss(SCRATCH2, SCRATCH);
}
// Duplicate result across vector // Duplicate result across vector
xorps(SRC1, SRC1); // break dependency chain xorps(SRC1, SRC1); // break dependency chain
@ -1122,33 +1180,69 @@ Xbyak::Label JitShader::CompilePrelude_Exp2() {
// Handle edge cases // Handle edge cases
ucomiss(SRC1, SRC1); ucomiss(SRC1, SRC1);
jp(ret_label); jp(ret_label);
// Clamp to maximum range since we shift the value directly into the exponent.
minss(SRC1, xword[rip + input_max]);
maxss(SRC1, xword[rip + input_min]);
// Decompose input // Decompose input:
movss(SCRATCH, SRC1); // SCRATCH=2^round(input)
movss(SCRATCH2, xword[rip + c0]); // Preload c0. // SRC1=input-round(input) [-0.5, 0.5)
subss(SCRATCH, xword[rip + half]); if (host_caps.has(Cpu::tAVX512F | Cpu::tAVX512VL)) {
cvtss2si(eax, SCRATCH); // input - 0.5
cvtsi2ss(SCRATCH, eax); vsubss(SCRATCH, SRC1, xword[rip + half]);
// SCRATCH now contains input rounded to the nearest integer.
add(eax, 0x7f); // trunc(input - 0.5)
subss(SRC1, SCRATCH); vrndscaless(SCRATCH2, SCRATCH, SCRATCH, _MM_FROUND_TRUNC);
// SRC1 contains input - round(input), which is in [-0.5, 0.5).
mulss(SCRATCH2, SRC1); // SCRATCH = 1 * 2^(trunc(input - 0.5))
shl(eax, 23); vscalefss(SCRATCH, ONE, SCRATCH2);
movd(SCRATCH, eax);
// SCRATCH contains 2^(round(input)). // SRC1 = input-trunc(input - 0.5)
vsubss(SRC1, SRC1, SCRATCH2);
} else {
// Clamp to maximum range since we shift the value directly into the exponent.
minss(SRC1, xword[rip + input_max]);
maxss(SRC1, xword[rip + input_min]);
if (host_caps.has(Cpu::tAVX)) {
vsubss(SCRATCH, SRC1, xword[rip + half]);
} else {
movss(SCRATCH, SRC1);
subss(SCRATCH, xword[rip + half]);
}
if (host_caps.has(Cpu::tSSE41)) {
roundss(SCRATCH, SCRATCH, _MM_FROUND_TRUNC);
cvtss2si(eax, SCRATCH);
} else {
cvtss2si(eax, SCRATCH);
cvtsi2ss(SCRATCH, eax);
}
// SCRATCH now contains input rounded to the nearest integer.
add(eax, 0x7f);
subss(SRC1, SCRATCH);
// SRC1 contains input - round(input), which is in [-0.5, 0.5).
shl(eax, 23);
movd(SCRATCH, eax);
// SCRATCH contains 2^(round(input)).
}
// Complete computation of polynomial. // Complete computation of polynomial.
addss(SCRATCH2, xword[rip + c1]); movss(SCRATCH2, xword[rip + c0]);
mulss(SCRATCH2, SRC1);
addss(SCRATCH2, xword[rip + c2]); if (host_caps.has(Cpu::tFMA)) {
mulss(SCRATCH2, SRC1); vfmadd213ss(SCRATCH2, SRC1, xword[rip + c1]);
addss(SCRATCH2, xword[rip + c3]); vfmadd213ss(SCRATCH2, SRC1, xword[rip + c2]);
mulss(SRC1, SCRATCH2); vfmadd213ss(SCRATCH2, SRC1, xword[rip + c3]);
addss(SRC1, xword[rip + c4]); vfmadd213ss(SRC1, SCRATCH2, xword[rip + c4]);
} else {
mulss(SCRATCH2, SRC1);
addss(SCRATCH2, xword[rip + c1]);
mulss(SCRATCH2, SRC1);
addss(SCRATCH2, xword[rip + c2]);
mulss(SCRATCH2, SRC1);
addss(SCRATCH2, xword[rip + c3]);
mulss(SRC1, SCRATCH2);
addss(SRC1, xword[rip + c4]);
}
mulss(SRC1, SCRATCH); mulss(SRC1, SCRATCH);
// Duplicate result across vector // Duplicate result across vector