Commit 175e7002 authored by Stelle, George Widgery; committed by George Stelle
Browse files

Fixed indexing issue

parent f9335de2
Loading
Loading
Loading
Loading
+39 −44
Original line number Diff line number Diff line
@@ -24,9 +24,7 @@
#include "llvm/Transforms/Scalar/GVN.h"
#include "llvm/Transforms/Vectorize.h"
#include "llvm/Support/TargetRegistry.h"
#ifdef KITSUNE_ENABLE_OPENCL_RUNTIME
#include <LLVMSPIRVLib/LLVMSPIRVLib.h>
#endif
#include <sstream>

using namespace llvm;
@@ -234,7 +232,6 @@ void SPIRVLoop::postProcessOutline(TapirLoopInfo &TL, TaskOutlineInfo &Out,
  ClonedSyncReg->eraseFromParent();

  // Set the helper function to have external linkage.

  // Get the thread ID for this invocation of Helper.
  IRBuilder<> B(Entry->getTerminator());
  Value *ThreadIdx = B.CreateCall(GetThreadIdx, ConstantInt::get(Int32Ty, 0));
@@ -242,46 +239,11 @@ void SPIRVLoop::postProcessOutline(TapirLoopInfo &TL, TaskOutlineInfo &Out,
  //Value *BlockDim = B.CreateCall(GetBlockDim, ConstantInt::get(Int32Ty, 0));
  Value *ThreadID = B.CreateIntCast(ThreadIdx, PrimaryIV->getType(), false);


  Function *Helper = Out.Outline;
  //Helper->setName("kitsune_spirv_kernel"); 
  // Fix argument pointer types to global, nocapture
  // TODO: read/write attributes?
  SmallVector<Type*, 8> paramTys; 
  for(auto &arg : Helper->args()){
    if (auto *apty = dyn_cast<PointerType>(arg.getType())){
      paramTys.push_back(PointerType::get(apty->getPointerElementType(), 1)); 
    } else {
      paramTys.push_back(arg.getType()); 
    }
  }
  ArrayRef<Type*> newParams(paramTys); 
  if(auto *fpty = dyn_cast<PointerType>(Helper->getType())){
    if(auto *fty = dyn_cast<FunctionType>(fpty->getPointerElementType())){
      LLVM_DEBUG(dbgs() << "Helper is pointer to function " << *Helper->getType() << "\n"); 
      auto *NewHelper = Function::Create(
          FunctionType::get(fty->getReturnType(), newParams, false), 
          GlobalValue::ExternalLinkage, 
          "kitsune_spirv_kernel", 
          SPIRVM);

      ValueToValueMapTy VMap;
      auto argit = NewHelper->arg_begin();
      for (auto &arg : Helper->args()) {
        VMap[&arg] = argit++; 
      }
      SmallVector< ReturnInst *,5> retinsts;
      CloneFunctionInto(NewHelper, Helper, VMap, false, retinsts);
      //Helper->mutateType(PointerType::get(FunctionType::get(fty->getReturnType(), newParams, false), 0)); 
      NewHelper->setCallingConv(CallingConv::SPIR_KERNEL); 
      for(auto &arg : NewHelper->args()){
        if (auto *apty = dyn_cast<PointerType>(arg.getType())){
          arg.addAttr(Attribute::NoCapture);
        }
      }
      Helper = NewHelper; 
    }
  }

  LLVM_DEBUG(dbgs() << "Function type after globalization of argument pointers << " << *Helper->getType() << "\n"); 
  LLVM_DEBUG(dbgs() << "SPIRVM after globalization of argument pointers << " << *Helper->getParent() << "\n"); 

@@ -316,15 +278,50 @@ void SPIRVLoop::postProcessOutline(TapirLoopInfo &TL, TaskOutlineInfo &Out,
  PrimaryIVInput->replaceAllUsesWith(ThreadID);

  // Update cloned loop condition to use the thread-end value.
  /*
  unsigned TripCountIdx = 0;
  ICmpInst *ClonedCond = cast<ICmpInst>(VMap[TL.getCondition()]);
  if (ClonedCond->getOperand(0) != ThreadEnd)
    ++TripCountIdx;
  ClonedCond->setOperand(TripCountIdx, ThreadEnd);
  assert(ClonedCond->getOperand(TripCountIdx) == ThreadEnd &&
         "End argument not used in condition");
  ClonedCond->setOperand(TripCountIdx, ThreadEnd);
  */

  // Update parameters with necessary address space modifications
  SmallVector<Type*, 8> paramTys; 
  for(auto &arg : Helper->args()){
    if (auto *apty = dyn_cast<PointerType>(arg.getType())){
      paramTys.push_back(PointerType::get(apty->getPointerElementType(), 1)); 
    } else {
      paramTys.push_back(arg.getType()); 
    }
  }
  ArrayRef<Type*> newParams(paramTys); 
  if(auto *fpty = dyn_cast<PointerType>(Helper->getType())){
    if(auto *fty = dyn_cast<FunctionType>(fpty->getPointerElementType())){
      LLVM_DEBUG(dbgs() << "Helper is pointer to function " << *Helper->getType() << "\n"); 
      auto *NewHelper = Function::Create(
          FunctionType::get(fty->getReturnType(), newParams, false), 
          GlobalValue::ExternalLinkage, 
          "kitsune_spirv_kernel", 
          SPIRVM);

      ValueToValueMapTy NewVMap;
      auto argit = NewHelper->arg_begin();
      for (auto &arg : Helper->args()) {
        NewVMap[&arg] = argit++; 
      }
      SmallVector< ReturnInst *,5> retinsts;
      CloneFunctionInto(NewHelper, Helper, NewVMap, false, retinsts);
      //Helper->mutateType(PointerType::get(FunctionType::get(fty->getReturnType(), newParams, false), 0)); 
      NewHelper->setCallingConv(CallingConv::SPIR_KERNEL); 
      for(auto &arg : NewHelper->args()){
        if (auto *apty = dyn_cast<PointerType>(arg.getType())){
          arg.addAttr(Attribute::NoCapture);
        }
      }
      Helper = NewHelper; 
    }
  }

  fixAddressSpaces(Helper); 

@@ -384,7 +381,6 @@ void SPIRVLoop::processOutlinedLoopCall(TapirLoopInfo &TL, TaskOutlineInfo &TOI,

  LLVM_DEBUG(dbgs() << "SPIRV Module: " << SPIRVM);

#ifdef KITSUNE_ENABLE_OPENCL_RUNTIME

  // generate spirv kernel code
  std::ostringstream str; 
@@ -399,7 +395,6 @@ void SPIRVLoop::processOutlinedLoopCall(TapirLoopInfo &TL, TaskOutlineInfo &TOI,
  SPIRVGlobal = new GlobalVariable(M, SPIRV->getType(), true,
                                 GlobalValue::PrivateLinkage, SPIRV,
                                 "spirv_" + Twine("kitsune_spirv_kernel"));
#endif

  //Value* TripCount = isSRetInput(TOI.InputSet[0]) ? TOI.InputSet[1] : TOI.InputSet[0]; 
  //Value *RunStart = ReplCall->getArgOperand(getIVArgIndex(*Parent,