1
0
Fork 0

video_core: Implement an arm64 shader-jit backend (#7002)

* externals: Add oaksim submodule

Used for emitting ARM64 assembly

* common: Implement aarch64 ABI

Utilize oaknut to implement a stack frame.

* tests: Allow shader-jit tests for x64 and a64

Run the shader-jit tests for both x86_64 and arm64 targets

* video_core: Initialize arm64 shader-jit backend

Passes all current unit tests!

* shader_jit_a64: protect/unprotect memory when jit-ing

Required on MacOS. Memory needs to be fully unprotected and then
re-protected when writing or there will be memory access errors on
MacOS.

* shader_jit_a64: Fix ARM64-Imm overflow

These conditionals were throwing exceptions since the immediate values
were overflowing the available space in the `EOR` instructions. Instead
they are generated from `MOV` and then `EOR`-ed after.

* shader_jit_a64: Fix Geometry shader conditional

* shader_jit_a64: Replace `ADRL` with `MOVP2R`

Fixes some immediate-generation exceptions.

* common/aarch64: Fix CallFarFunction

* shader_jit_a64: Optimize `SantitizedMul`

Co-authored-by: merryhime <merryhime@users.noreply.github.com>

* shader_jit_a64: Fix address register offset behavior

Based on https://github.com/citra-emu/citra/pull/6942
Passes unit tests.

* shader_jit_a64: Fix `RET` address offset

A64 stack is 16-byte aligned rather than 8. So a direct port of the x64
code won't work. Fixes weird branches into invalid memory for any
shaders with subroutines.

* shader_jit_a64: Increase max program size

Tuned for A64 program size.

* shader_jit_a64: Use `UBFX` for extracting loop-state

Co-authored-by: JosJuice <JosJuice@users.noreply.github.com>

* shader_jit_a64: Optimize `SUB+CMP` to `SUBS`

Co-authored-by: JosJuice <JosJuice@users.noreply.github.com>

* shader_jit_a64: Optimize `CMP+B` to `CBNZ`

Co-authored-by: JosJuice <JosJuice@users.noreply.github.com>

* shader_jit_a64: Use `FMOV` for `ONE` vector

Co-authored-by: JosJuice <JosJuice@users.noreply.github.com>

* shader_jit_a64: Remove x86-specific documentation

* shader_jit_a64: Use `UBFX` to extract exponent

Co-authored-by: JosJuice <JosJuice@users.noreply.github.com>

* shader_jit_a64: Remove redundant MIN/MAX `SRC2`-NaN check

Special handling only needs to check SRC1 for NaN, not SRC2.
It would work as follows in the four possible cases:

No NaN: No special handling needed.
Only SRC1 is NaN: The special handling is triggered because SRC1 is NaN, and SRC2 is picked.
Only SRC2 is NaN: FMAX automatically picks SRC2 because it always picks the NaN if there is one.
Both SRC1 and SRC2 are NaN: The special handling is triggered because SRC1 is NaN, and SRC2 is picked.

Co-authored-by: JosJuice <JosJuice@users.noreply.github.com>

* shader_jit/tests:: Add catch-stringifier for vec2f/vec3f

* shader_jit/tests: Add Dest Mask unit test

* shader_jit_a64: Fix Dest-Mask `BSL` operand order

Passes the dest-mask unit tests now.

* shader_jit_a64: Use `MOVI` for DestEnable mask

Accelerate certain cases of masking with MOVI as well

Co-authored-by: JosJuice <JosJuice@users.noreply.github.com>

* shader_jit/tests: Add source-swizzle unit test

This is not expansive. Generating all `4^4` cases seems to make Catch2
crash. So I've added some component-masking(non-reordering) tests based
on the Dest-Mask unit-test and some additional ones to test
broadcasts/splats and component re-ordering.

* shader_jit_a64: Fix swizzle index generation

This was still generating `SHUFPS` indices and not the ones that we wanted for the `TBL` instruction. Passes all unit tests now.

* shader_jit/tests: Add `ShaderSetup` constructor to `ShaderTest`

Rather than using the direct output of `CompileShaderSetup` allow a
`ShaderSetup` object to be passed in directly.  This enabled the ability
emit assembly that is not directly supported by nihstro.

* shader_jit/tests: Add `CALL` unit-test

Tests nested `CALL` instructions to eventually reach an `EX2`
instruction.

EX2 is picked in particular since it is implemented as an even deeper
dispatch and ensures subroutines are properly implemented between `CALL`
instructions and implementation-calls.

* shader_jit_a64: Fix nested `BL` subroutines

`lr` was getting writen over by nested calls to `BL`, causing undefined
behavior with mixtures of `CALL`, `EX2`, and `LG2` instructions.

Each usage of `BL` is now protected with a stach push/pop to preserve
and restore teh `lr` register to allow nested subroutines to work
properly.

* shader_jit/tests: Allocate generated tests on heap

Each of these generated shader-test objects were causing the stack to
overflow.  Allocate each of the generated tests on the heap and use
unique_ptr so they only exist within the life-time of the `REQUIRE`
statement.

* shader_jit_a64: Preserve `lr` register from external function calls

`EMIT` makes an external function call, and should be preserving `lr`

* shader_jit/tests: Add `MAD` unit-test

The Inline Asm version requires an upstream fix:
https://github.com/neobrain/nihstro/issues/68

Instead, the program code is manually configured and added.

* shader_jit/tests: Fix uninitialized instructions

These `union`-type instruction-types were uninitialized, causing tests
to indeterminantly fail at times.

* shader_jit_a64: Remove unneeded `MOV`

Residue from the direct-port of x64 code.

* shader_jit_a64: Use `std::array` for `instr_table`

Add some type-safety and const-correctness around this type as well.

* shader_jit_a64: Avoid c-style offset casting

Add some more const-correctness to this function as well.

* video_core: Add arch preprocessor comments

* common/aarch64: Use X16 as the veneer register

https://developer.arm.com/documentation/102374/0101/Procedure-Call-Standard

* shader_jit/tests: Add uniform reading unit-test

Particularly to ensure that addresses are being properly truncated

* common/aarch64: Use `X0` as `ABI_RETURN`

`X8` is used as the indirect return result value in the case that the
result is bigger than 128-bits. Principally `X0` is the general-case
return register though.

* common/aarch64: Add veneer register note

`LR` is generally overwritten by `BLR` anyways, and would also be a safe
veneer to utilize for far-calls.

* shader_jit_a64: Remove unneeded scratch register from `SanitizedMul`

* shader_jit_a64: Fix CALLU condition

Should be `EQ` not `NE`. Fixes the regression on Kid Icarus.
No known regressions anymore!

---------

Co-authored-by: merryhime <merryhime@users.noreply.github.com>
Co-authored-by: JosJuice <JosJuice@users.noreply.github.com>
This commit is contained in:
Wunk 2023-11-05 12:40:31 -08:00 committed by GitHub
parent 3218af38d0
commit e13735b624
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 1874 additions and 25 deletions

3
.gitmodules vendored
View File

@ -88,3 +88,6 @@
[submodule "libadrenotools"] [submodule "libadrenotools"]
path = externals/libadrenotools path = externals/libadrenotools
url = https://github.com/bylaws/libadrenotools url = https://github.com/bylaws/libadrenotools
[submodule "oaknut"]
path = externals/oaknut
url = https://github.com/merryhime/oaknut.git

View File

@ -85,6 +85,11 @@ if ("x86_64" IN_LIST ARCHITECTURE)
endif() endif()
endif() endif()
# Oaknut
if ("arm64" IN_LIST ARCHITECTURE)
add_subdirectory(oaknut EXCLUDE_FROM_ALL)
endif()
# Dynarmic # Dynarmic
if ("x86_64" IN_LIST ARCHITECTURE OR "arm64" IN_LIST ARCHITECTURE) if ("x86_64" IN_LIST ARCHITECTURE OR "arm64" IN_LIST ARCHITECTURE)
if(USE_SYSTEM_DYNARMIC) if(USE_SYSTEM_DYNARMIC)

1
externals/oaknut vendored Submodule

@ -0,0 +1 @@
Subproject commit e6eecc3f9460728be0a8d3f63e66d31c0362f472

View File

@ -53,6 +53,8 @@ add_custom_command(OUTPUT scm_rev.cpp
add_library(citra_common STATIC add_library(citra_common STATIC
aarch64/cpu_detect.cpp aarch64/cpu_detect.cpp
aarch64/cpu_detect.h aarch64/cpu_detect.h
aarch64/oaknut_abi.h
aarch64/oaknut_util.h
alignment.h alignment.h
android_storage.h android_storage.h
android_storage.cpp android_storage.cpp
@ -179,6 +181,10 @@ if ("x86_64" IN_LIST ARCHITECTURE)
target_link_libraries(citra_common PRIVATE xbyak) target_link_libraries(citra_common PRIVATE xbyak)
endif() endif()
if ("arm64" IN_LIST ARCHITECTURE)
target_link_libraries(citra_common PRIVATE oaknut)
endif()
if (CITRA_USE_PRECOMPILED_HEADERS) if (CITRA_USE_PRECOMPILED_HEADERS)
target_precompile_headers(citra_common PRIVATE precompiled_headers.h) target_precompile_headers(citra_common PRIVATE precompiled_headers.h)
endif() endif()

View File

@ -0,0 +1,155 @@
// Copyright 2023 Citra Emulator Project
// Licensed under GPLv2 or any later version
// Refer to the license.txt file included.
#pragma once
#include "common/arch.h"
#if CITRA_ARCH(arm64)
#include <bitset>
#include <initializer_list>
#include <oaknut/oaknut.hpp>
#include "common/assert.h"
namespace Common::A64 {
constexpr std::size_t RegToIndex(const oaknut::Reg& reg) {
ASSERT(reg.index() != 31); // ZR not allowed
return reg.index() + (reg.is_vector() ? 32 : 0);
}
constexpr oaknut::XReg IndexToXReg(std::size_t reg_index) {
ASSERT(reg_index <= 30);
return oaknut::XReg(static_cast<int>(reg_index));
}
constexpr oaknut::VReg IndexToVReg(std::size_t reg_index) {
ASSERT(reg_index >= 32 && reg_index < 64);
return oaknut::QReg(static_cast<int>(reg_index - 32));
}
constexpr oaknut::Reg IndexToReg(std::size_t reg_index) {
if (reg_index < 32) {
return IndexToXReg(reg_index);
} else {
return IndexToVReg(reg_index);
}
}
inline constexpr std::bitset<64> BuildRegSet(std::initializer_list<oaknut::Reg> regs) {
std::bitset<64> bits;
for (const oaknut::Reg& reg : regs) {
bits.set(RegToIndex(reg));
}
return bits;
}
constexpr inline std::bitset<64> ABI_ALL_GPRS(0x00000000'7FFFFFFF);
constexpr inline std::bitset<64> ABI_ALL_FPRS(0xFFFFFFFF'00000000);
constexpr inline oaknut::XReg ABI_RETURN = oaknut::util::X0;
constexpr inline oaknut::XReg ABI_PARAM1 = oaknut::util::X0;
constexpr inline oaknut::XReg ABI_PARAM2 = oaknut::util::X1;
constexpr inline oaknut::XReg ABI_PARAM3 = oaknut::util::X2;
constexpr inline oaknut::XReg ABI_PARAM4 = oaknut::util::X3;
constexpr std::bitset<64> ABI_ALL_CALLER_SAVED = 0xffffffff'4000ffff;
constexpr std::bitset<64> ABI_ALL_CALLEE_SAVED = 0x0000ff00'7ff80000;
struct ABIFrameInfo {
u32 subtraction;
u32 fprs_offset;
};
inline ABIFrameInfo ABI_CalculateFrameSize(std::bitset<64> regs, std::size_t frame_size) {
const size_t gprs_count = (regs & ABI_ALL_GPRS).count();
const size_t fprs_count = (regs & ABI_ALL_FPRS).count();
const size_t gprs_size = (gprs_count + 1) / 2 * 16;
const size_t fprs_size = fprs_count * 16;
size_t total_size = 0;
total_size += gprs_size;
const size_t fprs_base_subtraction = total_size;
total_size += fprs_size;
total_size += frame_size;
return ABIFrameInfo{static_cast<u32>(total_size), static_cast<u32>(fprs_base_subtraction)};
}
inline void ABI_PushRegisters(oaknut::CodeGenerator& code, std::bitset<64> regs,
std::size_t frame_size = 0) {
using namespace oaknut;
using namespace oaknut::util;
auto frame_info = ABI_CalculateFrameSize(regs, frame_size);
// Allocate stack-space
if (frame_info.subtraction != 0) {
code.SUB(SP, SP, frame_info.subtraction);
}
// TODO(wunk): Push pairs of registers at a time with STP
std::size_t offset = 0;
for (std::size_t i = 0; i < 32; ++i) {
if (regs[i] && ABI_ALL_GPRS[i]) {
const XReg reg = IndexToXReg(i);
code.STR(reg, SP, offset);
offset += 8;
}
}
offset = 0;
for (std::size_t i = 32; i < 64; ++i) {
if (regs[i] && ABI_ALL_FPRS[i]) {
const VReg reg = IndexToVReg(i);
code.STR(reg.toQ(), SP, u16(frame_info.fprs_offset + offset));
offset += 16;
}
}
// Allocate frame-space
if (frame_size != 0) {
code.SUB(SP, SP, frame_size);
}
}
inline void ABI_PopRegisters(oaknut::CodeGenerator& code, std::bitset<64> regs,
std::size_t frame_size = 0) {
using namespace oaknut;
using namespace oaknut::util;
auto frame_info = ABI_CalculateFrameSize(regs, frame_size);
// Free frame-space
if (frame_size != 0) {
code.ADD(SP, SP, frame_size);
}
// TODO(wunk): Pop pairs of registers at a time with LDP
std::size_t offset = 0;
for (std::size_t i = 0; i < 32; ++i) {
if (regs[i] && ABI_ALL_GPRS[i]) {
const XReg reg = IndexToXReg(i);
code.LDR(reg, SP, offset);
offset += 8;
}
}
offset = 0;
for (std::size_t i = 32; i < 64; ++i) {
if (regs[i] && ABI_ALL_FPRS[i]) {
const VReg reg = IndexToVReg(i);
code.LDR(reg.toQ(), SP, frame_info.fprs_offset + offset);
offset += 16;
}
}
// Free stack-space
if (frame_info.subtraction != 0) {
code.ADD(SP, SP, frame_info.subtraction);
}
}
} // namespace Common::A64
#endif // CITRA_ARCH(arm64)

View File

@ -0,0 +1,43 @@
// Copyright 2023 Citra Emulator Project
// Licensed under GPLv2 or any later version
// Refer to the license.txt file included.
#pragma once
#include "common/arch.h"
#if CITRA_ARCH(arm64)
#include <type_traits>
#include <oaknut/oaknut.hpp>
#include "common/aarch64/oaknut_abi.h"
namespace Common::A64 {
// BL can only reach targets within +-128MiB(24 bits)
inline bool IsWithin128M(uintptr_t ref, uintptr_t target) {
const u64 distance = target - (ref + 4);
return !(distance >= 0x800'0000ULL && distance <= ~0x800'0000ULL);
}
inline bool IsWithin128M(const oaknut::CodeGenerator& code, uintptr_t target) {
return IsWithin128M(code.ptr<uintptr_t>(), target);
}
template <typename T>
inline void CallFarFunction(oaknut::CodeGenerator& code, const T f) {
static_assert(std::is_pointer_v<T>, "Argument must be a (function) pointer.");
const std::uintptr_t addr = reinterpret_cast<std::uintptr_t>(f);
if (IsWithin128M(code, addr)) {
code.BL(reinterpret_cast<const void*>(f));
} else {
// X16(IP0) and X17(IP1) is the standard veneer register
// LR is also available as an intermediate register
// https://developer.arm.com/documentation/102374/0101/Procedure-Call-Standard
code.MOVP2R(oaknut::util::X16, reinterpret_cast<const void*>(f));
code.BLR(oaknut::util::X16);
}
}
} // namespace Common::A64
#endif // CITRA_ARCH(arm64)

View File

@ -15,7 +15,7 @@ add_executable(tests
audio_core/lle/lle.cpp audio_core/lle/lle.cpp
audio_core/audio_fixures.h audio_core/audio_fixures.h
audio_core/decoder_tests.cpp audio_core/decoder_tests.cpp
video_core/shader/shader_jit_x64_compiler.cpp video_core/shader/shader_jit_compiler.cpp
) )
create_target_directory_groups(tests) create_target_directory_groups(tests)

View File

@ -1,9 +1,9 @@
// Copyright 2017 Citra Emulator Project // Copyright 2023 Citra Emulator Project
// Licensed under GPLv2 or any later version // Licensed under GPLv2 or any later version
// Refer to the license.txt file included. // Refer to the license.txt file included.
#include "common/arch.h" #include "common/arch.h"
#if CITRA_ARCH(x86_64) #if CITRA_ARCH(x86_64) || CITRA_ARCH(arm64)
#include <algorithm> #include <algorithm>
#include <cmath> #include <cmath>
@ -14,7 +14,11 @@
#include <fmt/format.h> #include <fmt/format.h>
#include <nihstro/inline_assembly.h> #include <nihstro/inline_assembly.h>
#include "video_core/shader/shader_interpreter.h" #include "video_core/shader/shader_interpreter.h"
#if CITRA_ARCH(x86_64)
#include "video_core/shader/shader_jit_x64_compiler.h" #include "video_core/shader/shader_jit_x64_compiler.h"
#elif CITRA_ARCH(arm64)
#include "video_core/shader/shader_jit_a64_compiler.h"
#endif
using JitShader = Pica::Shader::JitShader; using JitShader = Pica::Shader::JitShader;
using ShaderInterpreter = Pica::Shader::InterpreterEngine; using ShaderInterpreter = Pica::Shader::InterpreterEngine;
@ -31,6 +35,18 @@ static constexpr Common::Vec4f vec4_zero = Common::Vec4f::AssignToAll(0.0f);
namespace Catch { namespace Catch {
template <> template <>
struct StringMaker<Common::Vec2f> {
static std::string convert(Common::Vec2f value) {
return fmt::format("({}, {})", value.x, value.y);
}
};
template <>
struct StringMaker<Common::Vec3f> {
static std::string convert(Common::Vec3f value) {
return fmt::format("({}, {}, {})", value.r(), value.g(), value.b());
}
};
template <>
struct StringMaker<Common::Vec4f> { struct StringMaker<Common::Vec4f> {
static std::string convert(Common::Vec4f value) { static std::string convert(Common::Vec4f value) {
return fmt::format("({}, {}, {}, {})", value.r(), value.g(), value.b(), value.a()); return fmt::format("({}, {}, {}, {})", value.r(), value.g(), value.b(), value.a());
@ -59,6 +75,11 @@ public:
shader_jit.Compile(&shader_setup->program_code, &shader_setup->swizzle_data); shader_jit.Compile(&shader_setup->program_code, &shader_setup->swizzle_data);
} }
explicit ShaderTest(std::unique_ptr<Pica::Shader::ShaderSetup> input_shader_setup)
: shader_setup(std::move(input_shader_setup)) {
shader_jit.Compile(&shader_setup->program_code, &shader_setup->swizzle_data);
}
Common::Vec4f Run(std::span<const Common::Vec4f> inputs) { Common::Vec4f Run(std::span<const Common::Vec4f> inputs) {
Pica::Shader::UnitState shader_unit; Pica::Shader::UnitState shader_unit;
RunJit(shader_unit, inputs); RunJit(shader_unit, inputs);
@ -144,6 +165,41 @@ TEST_CASE("ADD", "[video_core][shader][shader_jit]") {
REQUIRE(std::isinf(shader.Run({INFINITY, -1.0f}).x)); REQUIRE(std::isinf(shader.Run({INFINITY, -1.0f}).x));
} }
TEST_CASE("CALL", "[video_core][shader][shader_jit]") {
const auto sh_input = SourceRegister::MakeInput(0);
const auto sh_output = DestRegister::MakeOutput(0);
auto shader_setup = CompileShaderSetup({
{OpCode::Id::NOP}, // call foo
{OpCode::Id::END},
// .proc foo
{OpCode::Id::NOP}, // call ex2
{OpCode::Id::END},
// .proc ex2
{OpCode::Id::EX2, sh_output, sh_input},
{OpCode::Id::END},
});
// nihstro does not support the CALL* instructions, so the instruction-binary must be manually
// inserted here:
nihstro::Instruction CALL = {};
CALL.opcode = nihstro::OpCode(nihstro::OpCode::Id::CALL);
// call foo
CALL.flow_control.dest_offset = 2;
CALL.flow_control.num_instructions = 1;
shader_setup->program_code[0] = CALL.hex;
// call ex2
CALL.flow_control.dest_offset = 4;
CALL.flow_control.num_instructions = 1;
shader_setup->program_code[2] = CALL.hex;
auto shader = ShaderTest(std::move(shader_setup));
REQUIRE(shader.Run(0.f).x == Catch::Approx(1.f));
}
TEST_CASE("DP3", "[video_core][shader][shader_jit]") { TEST_CASE("DP3", "[video_core][shader][shader_jit]") {
const auto sh_input1 = SourceRegister::MakeInput(0); const auto sh_input1 = SourceRegister::MakeInput(0);
const auto sh_input2 = SourceRegister::MakeInput(1); const auto sh_input2 = SourceRegister::MakeInput(1);
@ -395,6 +451,39 @@ TEST_CASE("RSQ", "[video_core][shader][shader_jit]") {
REQUIRE(shader.Run({0.0625f}).x == Catch::Approx(4.0f).margin(0.004f)); REQUIRE(shader.Run({0.0625f}).x == Catch::Approx(4.0f).margin(0.004f));
} }
TEST_CASE("Uniform Read", "[video_core][shader][shader_jit]") {
const auto sh_input = SourceRegister::MakeInput(0);
const auto sh_c0 = SourceRegister::MakeFloat(0);
const auto sh_output = DestRegister::MakeOutput(0);
auto shader = ShaderTest({
// mova a0.x, sh_input.x
{OpCode::Id::MOVA, DestRegister{}, "x", sh_input, "x", SourceRegister{}, "",
nihstro::InlineAsm::RelativeAddress::A1},
// mov sh_output.xyzw, c0[a0.x].xyzw
{OpCode::Id::MOV, sh_output, "xyzw", sh_c0, "xyzw", SourceRegister{}, "",
nihstro::InlineAsm::RelativeAddress::A1},
{OpCode::Id::END},
});
// Prepare shader uniforms
std::array<Common::Vec4f, 96> f_uniforms = {};
for (u32 i = 0; i < 96; ++i) {
const float color = (i * 2.0f) / 255.0f;
const auto color_f24 = Pica::f24::FromFloat32(color);
shader.shader_setup->uniforms.f[i] = {color_f24, color_f24, color_f24, Pica::f24::One()};
f_uniforms[i] = {color, color, color, 1.0f};
}
for (u32 i = 0; i < 96; ++i) {
const float index = static_cast<float>(i);
// Add some fractional values to test proper float->integer truncation
const float fractional = (i % 17) / 17.0f;
REQUIRE(shader.Run(index + fractional) == f_uniforms[i]);
}
}
TEST_CASE("Address Register Offset", "[video_core][shader][shader_jit]") { TEST_CASE("Address Register Offset", "[video_core][shader][shader_jit]") {
const auto sh_input = SourceRegister::MakeInput(0); const auto sh_input = SourceRegister::MakeInput(0);
const auto sh_c40 = SourceRegister::MakeFloat(40); const auto sh_c40 = SourceRegister::MakeFloat(40);
@ -445,23 +534,83 @@ TEST_CASE("Address Register Offset", "[video_core][shader][shader_jit]") {
REQUIRE(shader.Run(-129.f) == f_uniforms[40]); REQUIRE(shader.Run(-129.f) == f_uniforms[40]);
} }
TEST_CASE("Dest Mask", "[video_core][shader][shader_jit]") {
const auto sh_input = SourceRegister::MakeInput(0);
const auto sh_output = DestRegister::MakeOutput(0);
const auto shader = [&sh_input, &sh_output](const char* dest_mask) {
return std::unique_ptr<ShaderTest>(new ShaderTest{
{OpCode::Id::MOV, sh_output, dest_mask, sh_input, "xyzw", SourceRegister{}, ""},
{OpCode::Id::END},
});
};
const Common::Vec4f iota_vec = {1.0f, 2.0f, 3.0f, 4.0f};
REQUIRE(shader("x")->Run({iota_vec}).x == iota_vec.x);
REQUIRE(shader("y")->Run({iota_vec}).y == iota_vec.y);
REQUIRE(shader("z")->Run({iota_vec}).z == iota_vec.z);
REQUIRE(shader("w")->Run({iota_vec}).w == iota_vec.w);
REQUIRE(shader("xy")->Run({iota_vec}).xy() == iota_vec.xy());
REQUIRE(shader("xz")->Run({iota_vec}).xz() == iota_vec.xz());
REQUIRE(shader("xw")->Run({iota_vec}).xw() == iota_vec.xw());
REQUIRE(shader("yz")->Run({iota_vec}).yz() == iota_vec.yz());
REQUIRE(shader("yw")->Run({iota_vec}).yw() == iota_vec.yw());
REQUIRE(shader("zw")->Run({iota_vec}).zw() == iota_vec.zw());
REQUIRE(shader("xyz")->Run({iota_vec}).xyz() == iota_vec.xyz());
REQUIRE(shader("xyw")->Run({iota_vec}).xyw() == iota_vec.xyw());
REQUIRE(shader("xzw")->Run({iota_vec}).xzw() == iota_vec.xzw());
REQUIRE(shader("yzw")->Run({iota_vec}).yzw() == iota_vec.yzw());
REQUIRE(shader("xyzw")->Run({iota_vec}) == iota_vec);
}
TEST_CASE("MAD", "[video_core][shader][shader_jit]") {
const auto sh_input1 = SourceRegister::MakeInput(0);
const auto sh_input2 = SourceRegister::MakeInput(1);
const auto sh_input3 = SourceRegister::MakeInput(2);
const auto sh_output = DestRegister::MakeOutput(0);
auto shader_setup = CompileShaderSetup({
// TODO: Requires fix from https://github.com/neobrain/nihstro/issues/68 // TODO: Requires fix from https://github.com/neobrain/nihstro/issues/68
// TEST_CASE("MAD", "[video_core][shader][shader_jit]") {
// const auto sh_input1 = SourceRegister::MakeInput(0);
// const auto sh_input2 = SourceRegister::MakeInput(1);
// const auto sh_input3 = SourceRegister::MakeInput(2);
// const auto sh_output = DestRegister::MakeOutput(0);
// auto shader = ShaderTest({
// {OpCode::Id::MAD, sh_output, sh_input1, sh_input2, sh_input3}, // {OpCode::Id::MAD, sh_output, sh_input1, sh_input2, sh_input3},
// {OpCode::Id::END}, {OpCode::Id::NOP},
// }); {OpCode::Id::END},
});
// REQUIRE(shader.Run({vec4_inf, vec4_zero, vec4_zero}).x == 0.0f); // nihstro does not support the MAD* instructions, so the instruction-binary must be manually
// REQUIRE(std::isnan(shader.Run({vec4_nan, vec4_zero, vec4_zero}).x)); // inserted here:
nihstro::Instruction MAD = {};
MAD.opcode = nihstro::OpCode::Id::MAD;
MAD.mad.operand_desc_id = 0;
MAD.mad.src1 = sh_input1;
MAD.mad.src2 = sh_input2;
MAD.mad.src3 = sh_input3;
MAD.mad.dest = sh_output;
shader_setup->program_code[0] = MAD.hex;
// REQUIRE(shader.Run({vec4_one, vec4_one, vec4_one}).x == 2.0f); nihstro::SwizzlePattern swizzle = {};
// } swizzle.dest_mask = 0b1111;
swizzle.SetSelectorSrc1(0, SwizzlePattern::Selector::x);
swizzle.SetSelectorSrc1(1, SwizzlePattern::Selector::y);
swizzle.SetSelectorSrc1(2, SwizzlePattern::Selector::z);
swizzle.SetSelectorSrc1(3, SwizzlePattern::Selector::w);
swizzle.SetSelectorSrc2(0, SwizzlePattern::Selector::x);
swizzle.SetSelectorSrc2(1, SwizzlePattern::Selector::y);
swizzle.SetSelectorSrc2(2, SwizzlePattern::Selector::z);
swizzle.SetSelectorSrc2(3, SwizzlePattern::Selector::w);
swizzle.SetSelectorSrc3(0, SwizzlePattern::Selector::x);
swizzle.SetSelectorSrc3(1, SwizzlePattern::Selector::y);
swizzle.SetSelectorSrc3(2, SwizzlePattern::Selector::z);
swizzle.SetSelectorSrc3(3, SwizzlePattern::Selector::w);
shader_setup->swizzle_data[0] = swizzle.hex;
auto shader = ShaderTest(std::move(shader_setup));
REQUIRE(shader.Run({vec4_zero, vec4_zero, vec4_zero}) == vec4_zero);
REQUIRE(shader.Run({vec4_one, vec4_one, vec4_one}) == (vec4_one * 2.0f));
REQUIRE(shader.Run({vec4_inf, vec4_zero, vec4_zero}) == vec4_zero);
REQUIRE(shader.Run({vec4_nan, vec4_zero, vec4_zero}) == vec4_nan);
}
TEST_CASE("Nested Loop", "[video_core][shader][shader_jit]") { TEST_CASE("Nested Loop", "[video_core][shader][shader_jit]") {
const auto sh_input = SourceRegister::MakeInput(0); const auto sh_input = SourceRegister::MakeInput(0);
@ -518,4 +667,42 @@ TEST_CASE("Nested Loop", "[video_core][shader][shader_jit]") {
} }
} }
#endif // CITRA_ARCH(x86_64) TEST_CASE("Source Swizzle", "[video_core][shader][shader_jit]") {
const auto sh_input = SourceRegister::MakeInput(0);
const auto sh_output = DestRegister::MakeOutput(0);
const auto shader = [&sh_input, &sh_output](const char* swizzle) {
return std::unique_ptr<ShaderTest>(new ShaderTest{
{OpCode::Id::MOV, sh_output, "xyzw", sh_input, swizzle, SourceRegister{}, ""},
{OpCode::Id::END},
});
};
const Common::Vec4f iota_vec = {1.0f, 2.0f, 3.0f, 4.0f};
REQUIRE(shader("x")->Run({iota_vec}).x == iota_vec.x);
REQUIRE(shader("y")->Run({iota_vec}).x == iota_vec.y);
REQUIRE(shader("z")->Run({iota_vec}).x == iota_vec.z);
REQUIRE(shader("w")->Run({iota_vec}).x == iota_vec.w);
REQUIRE(shader("xy")->Run({iota_vec}).xy() == iota_vec.xy());
REQUIRE(shader("xz")->Run({iota_vec}).xy() == iota_vec.xz());
REQUIRE(shader("xw")->Run({iota_vec}).xy() == iota_vec.xw());
REQUIRE(shader("yz")->Run({iota_vec}).xy() == iota_vec.yz());
REQUIRE(shader("yw")->Run({iota_vec}).xy() == iota_vec.yw());
REQUIRE(shader("zw")->Run({iota_vec}).xy() == iota_vec.zw());
REQUIRE(shader("yy")->Run({iota_vec}).xy() == iota_vec.yy());
REQUIRE(shader("wx")->Run({iota_vec}).xy() == iota_vec.wx());
REQUIRE(shader("xyz")->Run({iota_vec}).xyz() == iota_vec.xyz());
REQUIRE(shader("xyw")->Run({iota_vec}).xyz() == iota_vec.xyw());
REQUIRE(shader("xzw")->Run({iota_vec}).xyz() == iota_vec.xzw());
REQUIRE(shader("yzw")->Run({iota_vec}).xyz() == iota_vec.yzw());
REQUIRE(shader("yyy")->Run({iota_vec}).xyz() == iota_vec.yyy());
REQUIRE(shader("yxw")->Run({iota_vec}).xyz() == iota_vec.yxw());
REQUIRE(shader("xyzw")->Run({iota_vec}) == iota_vec);
REQUIRE(shader("wzxy")->Run({iota_vec}) ==
Common::Vec4f(iota_vec.w, iota_vec.z, iota_vec.x, iota_vec.y));
REQUIRE(shader("yyyy")->Run({iota_vec}) ==
Common::Vec4f(iota_vec.y, iota_vec.y, iota_vec.y, iota_vec.y));
}
#endif // CITRA_ARCH(x86_64) || CITRA_ARCH(arm64)

View File

@ -149,6 +149,10 @@ add_library(video_core STATIC
shader/shader.h shader/shader.h
shader/shader_interpreter.cpp shader/shader_interpreter.cpp
shader/shader_interpreter.h shader/shader_interpreter.h
shader/shader_jit_a64.cpp
shader/shader_jit_a64_compiler.cpp
shader/shader_jit_a64.h
shader/shader_jit_a64_compiler.h
shader/shader_jit_x64.cpp shader/shader_jit_x64.cpp
shader/shader_jit_x64_compiler.cpp shader/shader_jit_x64_compiler.cpp
shader/shader_jit_x64.h shader/shader_jit_x64.h
@ -177,6 +181,10 @@ if ("x86_64" IN_LIST ARCHITECTURE)
target_link_libraries(video_core PUBLIC xbyak) target_link_libraries(video_core PUBLIC xbyak)
endif() endif()
if ("arm64" IN_LIST ARCHITECTURE)
target_link_libraries(video_core PUBLIC oaknut)
endif()
if (CITRA_USE_PRECOMPILED_HEADERS) if (CITRA_USE_PRECOMPILED_HEADERS)
target_precompile_headers(video_core PRIVATE precompiled_headers.h) target_precompile_headers(video_core PRIVATE precompiled_headers.h)
endif() endif()

View File

@ -15,7 +15,9 @@
#include "video_core/shader/shader_interpreter.h" #include "video_core/shader/shader_interpreter.h"
#if CITRA_ARCH(x86_64) #if CITRA_ARCH(x86_64)
#include "video_core/shader/shader_jit_x64.h" #include "video_core/shader/shader_jit_x64.h"
#endif // CITRA_ARCH(x86_64) #elif CITRA_ARCH(arm64)
#include "video_core/shader/shader_jit_a64.h"
#endif
#include "video_core/video_core.h" #include "video_core/video_core.h"
namespace Pica::Shader { namespace Pica::Shader {
@ -141,27 +143,29 @@ MICROPROFILE_DEFINE(GPU_Shader, "GPU", "Shader", MP_RGB(50, 50, 240));
#if CITRA_ARCH(x86_64) #if CITRA_ARCH(x86_64)
static std::unique_ptr<JitX64Engine> jit_engine; static std::unique_ptr<JitX64Engine> jit_engine;
#endif // CITRA_ARCH(x86_64) #elif CITRA_ARCH(arm64)
static std::unique_ptr<JitA64Engine> jit_engine;
#endif
static InterpreterEngine interpreter_engine; static InterpreterEngine interpreter_engine;
ShaderEngine* GetEngine() { ShaderEngine* GetEngine() {
#if CITRA_ARCH(x86_64) #if CITRA_ARCH(x86_64) || CITRA_ARCH(arm64)
// TODO(yuriks): Re-initialize on each change rather than being persistent // TODO(yuriks): Re-initialize on each change rather than being persistent
if (VideoCore::g_shader_jit_enabled) { if (VideoCore::g_shader_jit_enabled) {
if (jit_engine == nullptr) { if (jit_engine == nullptr) {
jit_engine = std::make_unique<JitX64Engine>(); jit_engine = std::make_unique<decltype(jit_engine)::element_type>();
} }
return jit_engine.get(); return jit_engine.get();
} }
#endif // CITRA_ARCH(x86_64) #endif // CITRA_ARCH(x86_64) || CITRA_ARCH(arm64)
return &interpreter_engine; return &interpreter_engine;
} }
void Shutdown() { void Shutdown() {
#if CITRA_ARCH(x86_64) #if CITRA_ARCH(x86_64) || CITRA_ARCH(arm64)
jit_engine = nullptr; jit_engine = nullptr;
#endif // CITRA_ARCH(x86_64) #endif // CITRA_ARCH(x86_64) || CITRA_ARCH(arm64)
} }
} // namespace Pica::Shader } // namespace Pica::Shader

View File

@ -0,0 +1,51 @@
// Copyright 2023 Citra Emulator Project
// Licensed under GPLv2 or any later version
// Refer to the license.txt file included.
#include "common/arch.h"
#if CITRA_ARCH(arm64)
#include "common/assert.h"
#include "common/microprofile.h"
#include "video_core/shader/shader.h"
#include "video_core/shader/shader_jit_a64.h"
#include "video_core/shader/shader_jit_a64_compiler.h"
namespace Pica::Shader {
JitA64Engine::JitA64Engine() = default;
JitA64Engine::~JitA64Engine() = default;
void JitA64Engine::SetupBatch(ShaderSetup& setup, unsigned int entry_point) {
ASSERT(entry_point < MAX_PROGRAM_CODE_LENGTH);
setup.engine_data.entry_point = entry_point;
u64 code_hash = setup.GetProgramCodeHash();
u64 swizzle_hash = setup.GetSwizzleDataHash();
u64 cache_key = code_hash ^ swizzle_hash;
auto iter = cache.find(cache_key);
if (iter != cache.end()) {
setup.engine_data.cached_shader = iter->second.get();
} else {
auto shader = std::make_unique<JitShader>();
shader->Compile(&setup.program_code, &setup.swizzle_data);
setup.engine_data.cached_shader = shader.get();
cache.emplace_hint(iter, cache_key, std::move(shader));
}
}
MICROPROFILE_DECLARE(GPU_Shader);
void JitA64Engine::Run(const ShaderSetup& setup, UnitState& state) const {
ASSERT(setup.engine_data.cached_shader != nullptr);
MICROPROFILE_SCOPE(GPU_Shader);
const JitShader* shader = static_cast<const JitShader*>(setup.engine_data.cached_shader);
shader->Run(setup, state, setup.engine_data.entry_point);
}
} // namespace Pica::Shader
#endif // CITRA_ARCH(arm64)

View File

@ -0,0 +1,33 @@
// Copyright 2023 Citra Emulator Project
// Licensed under GPLv2 or any later version
// Refer to the license.txt file included.
#pragma once
#include "common/arch.h"
#if CITRA_ARCH(arm64)
#include <memory>
#include <unordered_map>
#include "common/common_types.h"
#include "video_core/shader/shader.h"
namespace Pica::Shader {
class JitShader;
class JitA64Engine final : public ShaderEngine {
public:
JitA64Engine();
~JitA64Engine() override;
void SetupBatch(ShaderSetup& setup, unsigned int entry_point) override;
void Run(const ShaderSetup& setup, UnitState& state) const override;
private:
std::unordered_map<u64, std::unique_ptr<JitShader>> cache;
};
} // namespace Pica::Shader
#endif // CITRA_ARCH(arm64)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,146 @@
// Copyright 2023 Citra Emulator Project
// Licensed under GPLv2 or any later version
// Refer to the license.txt file included.
#pragma once
#include "common/arch.h"
#if CITRA_ARCH(arm64)
#include <array>
#include <bitset>
#include <cstddef>
#include <optional>
#include <utility>
#include <vector>
#include <nihstro/shader_bytecode.h>
#include <oaknut/code_block.hpp>
#include <oaknut/oaknut.hpp>
#include "common/common_types.h"
#include "video_core/shader/shader.h"
using nihstro::Instruction;
using nihstro::OpCode;
using nihstro::SourceRegister;
using nihstro::SwizzlePattern;
namespace Pica::Shader {
/// Memory allocated for each compiled shader
constexpr std::size_t MAX_SHADER_SIZE = MAX_PROGRAM_CODE_LENGTH * 256;
/**
* This class implements the shader JIT compiler. It recompiles a Pica shader program into x86_64
* code that can be executed on the host machine directly.
*/
class JitShader : private oaknut::CodeBlock, public oaknut::CodeGenerator {
public:
JitShader();
void Run(const ShaderSetup& setup, UnitState& state, unsigned offset) const {
program(&setup.uniforms, &state, instruction_labels[offset].ptr<const std::byte*>());
}
void Compile(const std::array<u32, MAX_PROGRAM_CODE_LENGTH>* program_code,
const std::array<u32, MAX_SWIZZLE_DATA_LENGTH>* swizzle_data);
void Compile_ADD(Instruction instr);
void Compile_DP3(Instruction instr);
void Compile_DP4(Instruction instr);
void Compile_DPH(Instruction instr);
void Compile_EX2(Instruction instr);
void Compile_LG2(Instruction instr);
void Compile_MUL(Instruction instr);
void Compile_SGE(Instruction instr);
void Compile_SLT(Instruction instr);
void Compile_FLR(Instruction instr);
void Compile_MAX(Instruction instr);
void Compile_MIN(Instruction instr);
void Compile_RCP(Instruction instr);
void Compile_RSQ(Instruction instr);
void Compile_MOVA(Instruction instr);
void Compile_MOV(Instruction instr);
void Compile_NOP(Instruction instr);
void Compile_END(Instruction instr);
void Compile_BREAKC(Instruction instr);
void Compile_CALL(Instruction instr);
void Compile_CALLC(Instruction instr);
void Compile_CALLU(Instruction instr);
void Compile_IF(Instruction instr);
void Compile_LOOP(Instruction instr);
void Compile_JMP(Instruction instr);
void Compile_CMP(Instruction instr);
void Compile_MAD(Instruction instr);
void Compile_EMIT(Instruction instr);
void Compile_SETE(Instruction instr);
private:
void Compile_Block(unsigned end);
void Compile_NextInstr();
void Compile_SwizzleSrc(Instruction instr, unsigned src_num, SourceRegister src_reg,
oaknut::QReg dest);
void Compile_DestEnable(Instruction instr, oaknut::QReg dest);
/**
* Compiles a `MUL src1, src2` operation, properly handling the PICA semantics when multiplying
* zero by inf. Clobbers `src2` and `scratch`.
*/
void Compile_SanitizedMul(oaknut::QReg src1, oaknut::QReg src2, oaknut::QReg scratch0);
void Compile_EvaluateCondition(Instruction instr);
void Compile_UniformCondition(Instruction instr);
/**
* Emits the code to conditionally return from a subroutine envoked by the `CALL` instruction.
*/
void Compile_Return();
std::bitset<64> PersistentCallerSavedRegs();
/**
* Assertion evaluated at compile-time, but only triggered if executed at runtime.
* @param condition Condition to be evaluated.
* @param msg Message to be logged if the assertion fails.
*/
void Compile_Assert(bool condition, const char* msg);
/**
* Analyzes the entire shader program for `CALL` instructions before emitting any code,
* identifying the locations where a return needs to be inserted.
*/
void FindReturnOffsets();
/**
* Emits data and code for utility functions.
*/
void CompilePrelude();
oaknut::Label CompilePrelude_Log2();
oaknut::Label CompilePrelude_Exp2();
const std::array<u32, MAX_PROGRAM_CODE_LENGTH>* program_code = nullptr;
const std::array<u32, MAX_SWIZZLE_DATA_LENGTH>* swizzle_data = nullptr;
/// Mapping of Pica VS instructions to pointers in the emitted code
std::array<oaknut::Label, MAX_PROGRAM_CODE_LENGTH> instruction_labels;
/// Labels pointing to the end of each nested LOOP block. Used by the BREAKC instruction to
/// break out of a loop.
std::vector<oaknut::Label> loop_break_labels;
/// Offsets in code where a return needs to be inserted
std::vector<unsigned> return_offsets;
unsigned program_counter = 0; ///< Offset of the next instruction to decode
u8 loop_depth = 0; ///< Depth of the (nested) loops currently compiled
using CompiledShader = void(const void* setup, void* state, const std::byte* start_addr);
CompiledShader* program = nullptr;
oaknut::Label log2_subroutine;
oaknut::Label exp2_subroutine;
};
} // namespace Pica::Shader
#endif