Commit 797594a0 authored by Jakub Kuderski's avatar Jakub Kuderski
Browse files

[mlir][spirv] Fix nullptr dereference in UnifyAliasedResource

Fixes: https://github.com/llvm/llvm-project/issues/62368

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D149376
parent d636bcb6
Loading
Loading
Loading
Loading
+7 −4
Original line number Diff line number Diff line
@@ -220,6 +220,9 @@ ResourceAliasAnalysis::ResourceAliasAnalysis(Operation *root) {
}

bool ResourceAliasAnalysis::shouldUnify(Operation *op) const {
  if (!op)
    return false;

  if (auto varOp = dyn_cast<spirv::GlobalVariableOp>(op)) {
    auto canonicalOp = getCanonicalResource(varOp);
    return canonicalOp && varOp != canonicalOp;
@@ -566,16 +569,15 @@ public:
private:
  spirv::GetTargetEnvFn getTargetEnvFn;
};
} // namespace

void UnifyAliasedResourcePass::runOnOperation() {
  spirv::ModuleOp moduleOp = getOperation();
  MLIRContext *context = &getContext();

  if (getTargetEnvFn) {
    // This pass is only needed for targeting WebGPU, Metal, or layering Vulkan
    // on Metal via MoltenVK, where we need to translate SPIR-V into WGSL or
    // MSL. The translation has limitations.
    // This pass is only needed for targeting WebGPU, Metal, or layering
    // Vulkan on Metal via MoltenVK, where we need to translate SPIR-V into
    // WGSL or MSL. The translation has limitations.
    spirv::TargetEnvAttr targetEnv = getTargetEnvFn(moduleOp);
    spirv::ClientAPI clientAPI = targetEnv.getClientAPI();
    bool isVulkanOnAppleDevices =
@@ -614,6 +616,7 @@ void UnifyAliasedResourcePass::runOnOperation() {
      resources.front()->removeAttr("aliased");
  }
}
} // namespace

std::unique_ptr<mlir::OperationPass<spirv::ModuleOp>>
spirv::createUnifyAliasedResourcePass(spirv::GetTargetEnvFn getTargetEnv) {
+16 −0
Original line number Diff line number Diff line
@@ -506,3 +506,19 @@ spirv.module Logical GLSL450 {
// CHECK:   %[[CC:.+]] = spirv.CompositeConstruct %[[BC0]], %[[BC1]] : (vector<2xf32>, vector<2xf32>) -> vector<4xf32>
// CHECK:   spirv.ReturnValue %[[CC]]

// -----

// Make sure we do not crash on function arguments.

spirv.module Logical GLSL450 {
  spirv.func @main(%arg0: !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32, stride=4> [0])>, StorageBuffer>) "None" {
    %cst0_i32 = spirv.Constant 0 : i32
    %0 = spirv.AccessChain %arg0[%cst0_i32, %cst0_i32] : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32, stride=4> [0])>, StorageBuffer>, i32, i32
    spirv.Return
  }
}

// CHECK-LABEL: spirv.module
// CHECK-LABEL: spirv.func @main
// CHECK-SAME:  (%{{.+}}: !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32, stride=4> [0])>, StorageBuffer>) "None"
// CHECK:       spirv.Return