Skip to content

Commit

Permalink
Revert "Use shader register space for root constant arguments specifi…
Browse files Browse the repository at this point in the history
…cation"

Solution is incompatible with Metal argument buffers, which are created one per each register space.

This reverts commit 1e97c1f.
  • Loading branch information
egorodet committed Sep 22, 2024
1 parent 1e97c1f commit 38cd321
Show file tree
Hide file tree
Showing 14 changed files with 86 additions and 60 deletions.
11 changes: 10 additions & 1 deletion Apps/02-HelloCube/HelloCubeApp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,16 @@ class HelloCubeApp final // NOSONAR - destructor required
Rhi::ProgramInputBufferLayout::ArgumentSemantics{ "POSITION" , "COLOR" }
}
},
Rhi::ProgramArgumentAccessors{ },
Rhi::ProgramArgumentAccessors
{
#ifdef UNIFORMS_ENABLED
{ // Uniforms argument is declared as root constant
{ Rhi::ShaderType::Vertex, "g_uniforms" },
Rhi::ProgramArgumentAccessType::FrameConstant,
Rhi::ProgramArgumentValueType::RootConstant
}
#endif
},
GetScreenRenderPattern().GetAttachmentFormats()
}
),
Expand Down
2 changes: 1 addition & 1 deletion Apps/02-HelloCube/Shaders/HelloCube.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ struct PSInput
};

#ifdef UNIFORMS_ENABLED
ConstantBuffer<Uniforms> g_uniforms : register(b0, META_ARG_ROOT_FRAME_CONSTANT);
ConstantBuffer<Uniforms> g_uniforms : register(b0, META_ARG_FRAME_CONSTANT);
#endif

PSInput CubeVS(VSInput input)
Expand Down
4 changes: 2 additions & 2 deletions Apps/03-TexturedCube/Shaders/TexturedCube.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ struct PSInput
float2 texcoord : TEXCOORD;
};

ConstantBuffer<Constants> g_constants : register(b0, META_ARG_ROOT_CONSTANT);
ConstantBuffer<Uniforms> g_uniforms : register(b1, META_ARG_ROOT_FRAME_CONSTANT);
ConstantBuffer<Constants> g_constants : register(b0, META_ARG_CONSTANT);
ConstantBuffer<Uniforms> g_uniforms : register(b1, META_ARG_FRAME_CONSTANT);
Texture2D g_texture : register(t0, META_ARG_CONSTANT);
SamplerState g_sampler : register(s0, META_ARG_CONSTANT);

Expand Down
14 changes: 13 additions & 1 deletion Apps/03-TexturedCube/TexturedCubeApp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,19 @@ void TexturedCubeApp::Init()
rhi::Program::InputBufferLayout::ArgumentSemantics { cube_mesh.GetVertexLayout().GetSemantics() }
}
},
rhi::ProgramArgumentAccessors{ },
rhi::ProgramArgumentAccessors
{ // Define arguments as root constants
{
{ rhi::ShaderType::Pixel, "g_constants" },
rhi::ProgramArgumentAccessType::Constant,
rhi::ProgramArgumentValueType::RootConstant
},
{
{ rhi::ShaderType::All, "g_uniforms" },
rhi::ProgramArgumentAccessType::FrameConstant,
rhi::ProgramArgumentValueType::RootConstant
}
},
GetScreenRenderPattern().GetAttachmentFormats()
}
),
Expand Down
6 changes: 3 additions & 3 deletions Apps/04-ShadowCube/Shaders/ShadowCube.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ struct PSInput
#endif
};

ConstantBuffer<Constants> g_constants : register(b0, META_ARG_ROOT_CONSTANT);
ConstantBuffer<SceneUniforms> g_scene_uniforms : register(b1, META_ARG_ROOT_FRAME_CONSTANT);
ConstantBuffer<MeshUniforms> g_mesh_uniforms : register(b2, META_ARG_ROOT_MUTABLE);
ConstantBuffer<Constants> g_constants : register(b0, META_ARG_CONSTANT);
ConstantBuffer<SceneUniforms> g_scene_uniforms : register(b1, META_ARG_FRAME_CONSTANT);
ConstantBuffer<MeshUniforms> g_mesh_uniforms : register(b2, META_ARG_MUTABLE);

#ifdef ENABLE_SHADOWS
Texture2D g_shadow_map : register(t0, META_ARG_FRAME_CONSTANT);
Expand Down
19 changes: 18 additions & 1 deletion Apps/04-ShadowCube/ShadowCubeApp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,24 @@ void ShadowCubeApp::Init()
rhi::Program::InputBufferLayout::ArgumentSemantics { cube_mesh.GetVertexLayout().GetSemantics() }
}
},
rhi::ProgramArgumentAccessors{ },
rhi::ProgramArgumentAccessors
{
{
{ rhi::ShaderType::Pixel, "g_constants" },
rhi::ProgramArgumentAccessType::Constant,
rhi::ProgramArgumentValueType::RootConstant
},
{
{ rhi::ShaderType::Pixel, "g_scene_uniforms" },
rhi::ProgramArgumentAccessType::FrameConstant,
rhi::ProgramArgumentValueType::RootConstant
},
{
{ rhi::ShaderType::Vertex, "g_mesh_uniforms" },
rhi::ProgramArgumentAccessType::Mutable,
rhi::ProgramArgumentValueType::RootConstant
}
},
GetScreenRenderPattern().GetAttachmentFormats()
}
),
Expand Down
7 changes: 2 additions & 5 deletions CMake/MethaneShaders.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,8 @@ endfunction()
function(get_hlsl_compile_definitions OUT_HLSL_COMPILE_DEFINITIONS)
set(${OUT_HLSL_COMPILE_DEFINITIONS}
-D META_ARG_CONSTANT=space0
-D META_ARG_ROOT_CONSTANT=space1
-D META_ARG_FRAME_CONSTANT=space2
-D META_ARG_ROOT_FRAME_CONSTANT=space3
-D META_ARG_MUTABLE=space4
-D META_ARG_ROOT_MUTABLE=space5
-D META_ARG_FRAME_CONSTANT=space1
-D META_ARG_MUTABLE=space2
PARENT_SCOPE)
endfunction()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,12 +200,11 @@ Ptrs<Base::ProgramArgumentBinding> Shader::GetArgumentBindings(const Rhi::Progra
D3D12_SHADER_INPUT_BIND_DESC binding_desc{};
ThrowIfFailed(m_reflection_cptr->GetResourceBindingDesc(resource_index, &binding_desc));

const Rhi::ProgramArgumentAccessor::Types arg_types = Rhi::ProgramArgumentAccessor::GetTypeByRegisterSpace(binding_desc.Space);
const Rhi::ProgramArgumentAccessType arg_access_type = Rhi::ProgramArgumentAccessor::GetTypeByRegisterSpace(binding_desc.Space);
const Rhi::ProgramArgument shader_argument(GetType(), Base::Shader::GetCachedArgName(binding_desc.Name));
const Rhi::ProgramArgumentAccessor* argument_ptr = Rhi::IProgram::FindArgumentAccessor(argument_accessors, shader_argument);
const Rhi::ProgramArgumentAccessor argument_acc = argument_ptr
? *argument_ptr
: Rhi::ProgramArgumentAccessor(shader_argument, arg_types.first, arg_types.second);
const Rhi::ProgramArgumentAccessor argument_acc = argument_ptr ? *argument_ptr
: Rhi::ProgramArgumentAccessor(shader_argument, arg_access_type);
const D3D12_SHADER_BUFFER_DESC buffer_desc = GetConstantBufferDesc(*m_reflection_cptr.Get(), binding_desc);

ProgramArgumentBindingType dx_binding_type = ProgramArgumentBindingType::DescriptorTable;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,12 @@ class ProgramArgumentNotFoundException : public std::invalid_argument
};

// NOTE: Access Type enum values should strictly match with
// register space value divided by two of 'META_ARG_*' shader definitions from MethaneShaders.cmake.
// EVEN space values are related to resource views, while ODD space values relate to root constants
// register space values of 'META_ARG_*' shader definitions from MethaneShaders.cmake:
enum class ProgramArgumentAccessType : uint32_t
{
// EVEN: Resource View ODD: Root Constant
Constant, // META_ARG_CONSTANT(0), META_ARG_ROOT_CONSTANT(1)
FrameConstant, // META_ARG_FRAME_CONSTANT(2), META_ARG_ROOT_FRAME_CONSTANT(3)
Mutable // META_ARG_MUTABLE(4), META_ARG_ROOT_MUTABLE(5)
Constant, // META_ARG_CONSTANT(0)
FrameConstant, // META_ARG_FRAME_CONSTANT(1)
Mutable // META_ARG_MUTABLE(2)
};

using ProgramArgumentAccessMask = Data::EnumMask<ProgramArgumentAccessType>;
Expand All @@ -114,12 +112,11 @@ using ProgramArguments = std::unordered_set<ProgramArgument, ProgramArgument::Ha
class ProgramArgumentAccessor : public ProgramArgument
{
public:
using Mask = ProgramArgumentAccessMask;
using Type = ProgramArgumentAccessType;
using Mask = ProgramArgumentAccessMask;
using ValueType = ProgramArgumentValueType;
using Types = std::pair<Type, ValueType>;

static Types GetTypeByRegisterSpace(uint32_t register_space);
static Type GetTypeByRegisterSpace(uint32_t register_space);

ProgramArgumentAccessor(ShaderType shader_type,
std::string_view arg_name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,12 @@ void ProgramArgument::MergeShaderTypes(ShaderType shader_type)
m_hash = GetProgramArgumentHash(m_shader_type, m_name);
}

ProgramArgumentAccessor::Types ProgramArgumentAccessor::GetTypeByRegisterSpace(uint32_t register_space)
ProgramArgumentAccessor::Type ProgramArgumentAccessor::GetTypeByRegisterSpace(uint32_t register_space)
{
META_FUNCTION_TASK();
META_CHECK_LESS_DESCR(register_space, magic_enum::enum_count<Type>() * 2U,
"shader register space is out of values range for Rhi::ProgramArgumentAccessType enum");
const auto arg_type = static_cast<Type>(register_space / 2);
const auto val_type = register_space % 2
? ValueType::RootConstant
: ValueType::ResourceView;
return std::make_pair(arg_type, val_type);
META_CHECK_LESS_DESCR(register_space, magic_enum::enum_count<ProgramArgumentAccessor::Type>(),
"shader register space is out of values range for Rhi::ProgramArgumentAccessType enum");
return static_cast<ProgramArgumentAccessor::Type>(register_space);
}

ProgramArgumentAccessor::ProgramArgumentAccessor(ShaderType shader_type, std::string_view arg_name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ private:
// Base::Program overrides
void InitArgumentBindings() override;

using ArgumentBufferSizeByAccessType = std::array<Data::Size, magic_enum::enum_count<Rhi::ProgramArgumentAccessType>() * 2>;
using ArgumentBufferSizeByAccessType = std::array<Data::Size, magic_enum::enum_count<Rhi::ProgramArgumentAccessType>()>;

MTLVertexDescriptor* m_mtl_vertex_desc = nil;
Data::Index m_start_vertex_buffer_index = 0U;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,21 +189,21 @@
Data::Index arg_buffer_index = 0U;
for(const ArgumentBufferLayout& arg_buffer_layout : metal_shader.GetArgumentBufferLayouts())
{
// NOTE: see comment for ProgramArgumentAccessType why index is divided by two to get access index:
const Data::Index arg_acess_index = arg_buffer_index / 2;
META_CHECK_LESS(arg_acess_index, m_arguments_range_size_by_access_type.size());
arg_buffer_index++;

META_CHECK_LESS(arg_buffer_index, m_arguments_range_size_by_access_type.size());
if (!arg_buffer_layout.data_size)
{
arg_buffer_index++;
continue;
}

m_shader_argument_buffer_layouts.push_back({
shader.GetType(),
arg_buffer_layout.data_size,
static_cast<Rhi::ProgramArgumentAccessType>(arg_acess_index)
static_cast<Rhi::ProgramArgumentAccessType>(arg_buffer_index)
});

m_arguments_range_size_by_access_type[arg_acess_index] += arg_buffer_layout.data_size;
m_arguments_range_size_by_access_type[arg_buffer_index] += arg_buffer_layout.data_size;
arg_buffer_index++;
}
});

Expand Down
27 changes: 13 additions & 14 deletions Modules/Graphics/RHI/Metal/Sources/Methane/Graphics/Metal/Shader.mm
Original file line number Diff line number Diff line change
Expand Up @@ -169,13 +169,12 @@ static uint32_t GetBufferSizeOfStructMember(MTLStructMember* mtl_struct_member)
}

[[nodiscard]]
static bool IsArgumentBufferName(std::string_view argument_name, Rhi::ProgramArgumentAccessor::Types& arg_types)
static bool IsArgumentBufferName(std::string_view argument_name, Rhi::ProgramArgumentAccessType& arg_access_type)
{
META_FUNCTION_TASK();
if (argument_name == "top_level_global_ab")
{
arg_types.first = Rhi::ProgramArgumentAccessType::Mutable;
arg_types.second = Rhi::ProgramArgumentValueType::ResourceView;
arg_access_type = Rhi::ProgramArgumentAccessType::Mutable;
return true;
}

Expand All @@ -187,7 +186,7 @@ static bool IsArgumentBufferName(std::string_view argument_name, Rhi::ProgramArg
id_match.size() == 2)
{
const uint32_t register_space = std::stoi(id_match[1]);
arg_types = Rhi::ProgramArgumentAccessor::GetTypeByRegisterSpace(register_space);
arg_access_type = Rhi::ProgramArgumentAccessor::GetTypeByRegisterSpace(register_space);
return true;
}

Expand Down Expand Up @@ -311,7 +310,7 @@ static bool IsArgumentBufferName(std::string_view argument_name, Rhi::ProgramArg

const auto add_argument_binding = [this, &argument_bindings, &argument_accessors]
(const std::string& argument_name,
Rhi::ProgramArgumentAccessor::Types arg_types,
Rhi::ProgramArgumentAccessType argument_access,
Rhi::ResourceType resource_type,
uint32_t array_length,
uint32_t buffer_size,
Expand All @@ -322,7 +321,7 @@ static bool IsArgumentBufferName(std::string_view argument_name, Rhi::ProgramArg
const Rhi::ProgramArgumentAccessor* argument_accessor_ptr = Rhi::IProgram::FindArgumentAccessor(argument_accessors, shader_argument);
const Rhi::ProgramArgumentAccessor& argument_accessor = argument_accessor_ptr
? *argument_accessor_ptr
: Rhi::ProgramArgumentAccessor(shader_argument, arg_types.first, arg_types.second);
: Rhi::ProgramArgumentAccessor(shader_argument, argument_access);

ProgramArgumentBindingSettings::StructOffsetByShaderType argument_buffer_offset_by_shader_type;
if (argument_buffer_offset_opt)
Expand Down Expand Up @@ -361,17 +360,17 @@ static bool IsArgumentBufferName(std::string_view argument_name, Rhi::ProgramArg
}

const auto argument_index = static_cast<uint32_t>(mtl_binding.index);
Rhi::ProgramArgumentAccessor::Types arg_types = { Rhi::ProgramArgumentAccessType::Mutable, Rhi::ProgramArgumentValueType::ResourceView };
if (mtl_binding.type == MTLBindingTypeBuffer && IsArgumentBufferName(argument_name, arg_types))
Rhi::ProgramArgumentAccessType arg_access_type = Rhi::ProgramArgumentAccessType::Mutable;
if (mtl_binding.type == MTLBindingTypeBuffer && IsArgumentBufferName(argument_name, arg_access_type))
{
// Get arguments from argument buffer layout
META_CHECK_LESS_DESCR(argument_index, m_argument_buffer_layouts.size(),
"inconsistent argument buffer layouts");
"inconsistent argument buffer layouts");
for(const auto& [name, member] : m_argument_buffer_layouts[argument_index].member_by_name)
{
add_argument_binding(
name,
arg_types,
arg_access_type,
member.resource_type,
member.array_size,
member.buffer_size,
Expand All @@ -384,7 +383,7 @@ static bool IsArgumentBufferName(std::string_view argument_name, Rhi::ProgramArg
{
add_argument_binding(
argument_name,
arg_types,
arg_access_type,
GetResourceTypeByMetalBindingType(mtl_binding.type),
GetBindingArrayLength(mtl_binding),
GetBindingBufferSize(mtl_binding),
Expand All @@ -405,11 +404,11 @@ static bool IsArgumentBufferName(std::string_view argument_name, Rhi::ProgramArg
// Fill argument buffer layouts
for(id<MTLBinding> mtl_binding in m_mtl_bindings)
{
Rhi::ProgramArgumentAccessor::Types arg_types = { Rhi::ProgramArgumentAccessType::Mutable, Rhi::ProgramArgumentValueType::ResourceView };
const std::string arg_name = MacOS::ConvertFromNsString(mtl_binding.name);
Rhi::ProgramArgumentAccessType arg_access_type = Rhi::ProgramArgumentAccessType::Mutable;
const std::string argument_name = MacOS::ConvertFromNsString(mtl_binding.name);
if (mtl_binding.argument && mtl_binding.used &&
mtl_binding.type == MTLBindingTypeBuffer &&
IsArgumentBufferName(arg_name, arg_types))
IsArgumentBufferName(argument_name, arg_access_type))
{
// Argument buffer structure
const auto argument_index = static_cast<uint32_t>(mtl_binding.index);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,13 +174,13 @@ static void AddSpirvResourcesToArgumentBindings(const spirv_cross::Compiler& spi
META_CHECK_TRUE(spirv_compiler.get_binary_offset_for_decoration(resource.id, spv::DecorationBinding, byte_code_map.binding_offset));

const uint32_t descriptor_set_id = spirv_compiler.get_decoration(resource.id, spv::DecorationDescriptorSet);
const Rhi::ProgramArgumentAccessor::Types arg_types = Rhi::ProgramArgumentAccessor::GetTypeByRegisterSpace(descriptor_set_id);
const Rhi::ProgramArgumentAccessType arg_access_type = Rhi::ProgramArgumentAccessor::GetTypeByRegisterSpace(descriptor_set_id);

const Rhi::ProgramArgument shader_argument(shader_type, shader.GetCachedArgName(spirv_compiler.get_name(resource.id)));
const Rhi::ProgramArgumentAccessor* argument_accessor_ptr = Rhi::IProgram::FindArgumentAccessor(argument_accessors, shader_argument);
const Rhi::ProgramArgumentAccessor argument_acc = argument_accessor_ptr
? *argument_accessor_ptr
: Rhi::ProgramArgumentAccessor(shader_argument, arg_types.first, arg_types.second);
? *argument_accessor_ptr
: Rhi::ProgramArgumentAccessor(shader_argument, arg_access_type);

argument_bindings.push_back(std::make_shared<ProgramBindings::ArgumentBinding>(
shader.GetContext(),
Expand Down

0 comments on commit 38cd321

Please sign in to comment.