Commit f1342c74 authored by Muhammad Omair Javaid's avatar Muhammad Omair Javaid
Browse files

[LLDB] AArch64 SVE restore SVE registers after expression

This patch fixes register save/restore on expression call to also include SVE registers.

This will fix expression calls like:

re re p1

<Register Value P1 before expression>

p <var-name or function call>

re re p1

<Register Value P1 after expression>

In above example register P1 should remain the same before and after the expression evaluation.

Reviewed By: DavidSpickett

Differential Revision: https://reviews.llvm.org/D108739
parent fbb8b415
Loading
Loading
Loading
Loading
+115 −29
Original line number Diff line number Diff line
@@ -46,8 +46,6 @@
#define HWCAP_PACA (1 << 30)
#define HWCAP2_MTE (1 << 18)

#define REG_CONTEXT_SIZE (GetGPRSize() + GetFPRSize())

using namespace lldb;
using namespace lldb_private;
using namespace lldb_private::process_linux;
@@ -452,45 +450,80 @@ Status NativeRegisterContextLinux_arm64::WriteRegister(

Status NativeRegisterContextLinux_arm64::ReadAllRegisterValues(
    lldb::DataBufferSP &data_sp) {
  Status error;
  // AArch64 register data must contain GPRs, either FPR or SVE registers
  // and optional MTE register. Pointer Authentication (PAC) registers are
  // read-only and will be skiped.

  data_sp.reset(new DataBufferHeap(REG_CONTEXT_SIZE, 0));
  // In order to create register data checkpoint we first read all register
  // values if not done already and calculate total size of register set data.
  // We store all register values in data_sp by copying full PTrace data that
  // corresponds to register sets enabled by current register context.

  Status error;
  uint32_t reg_data_byte_size = GetGPRBufferSize();
  error = ReadGPR();
  if (error.Fail())
    return error;

  // If SVE is enabled we need not copy FPR separately.
  if (GetRegisterInfo().IsSVEEnabled()) {
    reg_data_byte_size += GetSVEBufferSize();
    error = ReadAllSVE();
  } else {
    reg_data_byte_size += GetFPRSize();
    error = ReadFPR();
  }
  if (error.Fail())
    return error;

  if (GetRegisterInfo().IsMTEEnabled()) {
    reg_data_byte_size += GetMTEControlSize();
    error = ReadMTEControl();
    if (error.Fail())
      return error;
  }

  data_sp.reset(new DataBufferHeap(reg_data_byte_size, 0));
  uint8_t *dst = data_sp->GetBytes();
  ::memcpy(dst, GetGPRBuffer(), GetGPRSize());
  dst += GetGPRSize();

  ::memcpy(dst, GetGPRBuffer(), GetGPRBufferSize());
  dst += GetGPRBufferSize();

  if (GetRegisterInfo().IsSVEEnabled()) {
    ::memcpy(dst, GetSVEBuffer(), GetSVEBufferSize());
    dst += GetSVEBufferSize();
  } else {
    ::memcpy(dst, GetFPRBuffer(), GetFPRSize());
    dst += GetFPRSize();
  }

  if (GetRegisterInfo().IsMTEEnabled())
    ::memcpy(dst, GetMTEControl(), GetMTEControlSize());

  return error;
}

Status NativeRegisterContextLinux_arm64::WriteAllRegisterValues(
    const lldb::DataBufferSP &data_sp) {
  Status error;
  // AArch64 register data must contain GPRs, either FPR or SVE registers
  // and optional MTE register. Pointer Authentication (PAC) registers are
  // read-only and will be skiped.

  // We store all register values in data_sp by copying full PTrace data that
  // corresponds to register sets enabled by current register context. In order
  // to restore from register data checkpoint we will first restore GPRs, based
  // on size of remaining register data either SVE or FPRs should be restored
  // next. SVE is not enabled if we have register data size less than or equal
  // to size of GPR + FPR + MTE.

  Status error;
  if (!data_sp) {
    error.SetErrorStringWithFormat(
        "NativeRegisterContextLinux_x86_64::%s invalid data_sp provided",
        "NativeRegisterContextLinux_arm64::%s invalid data_sp provided",
        __FUNCTION__);
    return error;
  }

  if (data_sp->GetByteSize() != REG_CONTEXT_SIZE) {
    error.SetErrorStringWithFormat(
        "NativeRegisterContextLinux_x86_64::%s data_sp contained mismatched "
        "data size, expected %" PRIu64 ", actual %" PRIu64,
        __FUNCTION__, REG_CONTEXT_SIZE, data_sp->GetByteSize());
    return error;
  }

  uint8_t *src = data_sp->GetBytes();
  if (src == nullptr) {
    error.SetErrorStringWithFormat("NativeRegisterContextLinux_x86_64::%s "
@@ -499,19 +532,79 @@ Status NativeRegisterContextLinux_arm64::WriteAllRegisterValues(
                                   __FUNCTION__);
    return error;
  }
  ::memcpy(GetGPRBuffer(), src, GetRegisterInfoInterface().GetGPRSize());

  uint64_t reg_data_min_size = GetGPRBufferSize() + GetFPRSize();
  if (data_sp->GetByteSize() < reg_data_min_size) {
    error.SetErrorStringWithFormat(
        "NativeRegisterContextLinux_arm64::%s data_sp contained insufficient "
        "register data bytes, expected at least %" PRIu64 ", actual %" PRIu64,
        __FUNCTION__, reg_data_min_size, data_sp->GetByteSize());
    return error;
  }

  // Register data starts with GPRs
  ::memcpy(GetGPRBuffer(), src, GetGPRBufferSize());
  m_gpr_is_valid = true;

  error = WriteGPR();
  if (error.Fail())
    return error;

  src += GetRegisterInfoInterface().GetGPRSize();
  ::memcpy(GetFPRBuffer(), src, GetFPRSize());
  src += GetGPRBufferSize();

  // Verify if register data may contain SVE register values.
  bool contains_sve_reg_data =
      (data_sp->GetByteSize() > (reg_data_min_size + GetSVEHeaderSize()));

  if (contains_sve_reg_data) {
    // We have SVE register data first write SVE header.
    ::memcpy(GetSVEHeader(), src, GetSVEHeaderSize());
    if (!sve_vl_valid(m_sve_header.vl)) {
      m_sve_header_is_valid = false;
      error.SetErrorStringWithFormat("NativeRegisterContextLinux_arm64::%s "
                                     "Invalid SVE header in data_sp",
                                     __FUNCTION__);
      return error;
    }
    m_sve_header_is_valid = true;
    error = WriteSVEHeader();
    if (error.Fail())
      return error;

    // SVE header has been written configure SVE vector length if needed.
    ConfigureRegisterContext();

    // Make sure data_sp contains sufficient data to write all SVE registers.
    reg_data_min_size = GetGPRBufferSize() + GetSVEBufferSize();
    if (data_sp->GetByteSize() < reg_data_min_size) {
      error.SetErrorStringWithFormat(
          "NativeRegisterContextLinux_arm64::%s data_sp contained insufficient "
          "register data bytes, expected %" PRIu64 ", actual %" PRIu64,
          __FUNCTION__, reg_data_min_size, data_sp->GetByteSize());
      return error;
    }

    ::memcpy(GetSVEBuffer(), src, GetSVEBufferSize());
    m_sve_buffer_is_valid = true;
    error = WriteAllSVE();
    src += GetSVEBufferSize();
  } else {
    ::memcpy(GetFPRBuffer(), src, GetFPRSize());
    m_fpu_is_valid = true;
    error = WriteFPR();
    src += GetFPRSize();
  }

  if (error.Fail())
    return error;

  if (GetRegisterInfo().IsMTEEnabled() &&
      data_sp->GetByteSize() > reg_data_min_size) {
    ::memcpy(GetMTEControl(), src, GetMTEControlSize());
    m_mte_ctrl_is_valid = true;
    error = WriteMTEControl();
  }

  return error;
}

@@ -864,13 +957,6 @@ uint32_t NativeRegisterContextLinux_arm64::CalculateSVEOffset(
  return sve_reg_offset;
}

void *NativeRegisterContextLinux_arm64::GetSVEBuffer() {
  if (m_sve_state == SVEState::FPSIMD)
    return m_sve_ptrace_payload.data() + sve::ptrace_fpsimd_offset;

  return m_sve_ptrace_payload.data();
}

std::vector<uint32_t> NativeRegisterContextLinux_arm64::GetExpeditedRegisters(
    ExpeditedRegs expType) const {
  std::vector<uint32_t> expedited_reg_nums =
+1 −1
Original line number Diff line number Diff line
@@ -139,7 +139,7 @@ private:

  void *GetMTEControl() { return &m_mte_ctrl_reg; }

  void *GetSVEBuffer();
  void *GetSVEBuffer() { return m_sve_ptrace_payload.data(); };

  size_t GetSVEHeaderSize() { return sizeof(m_sve_header); }

+1 −0
Original line number Diff line number Diff line
@@ -110,6 +110,7 @@ public:

  bool IsSVEEnabled() const { return m_opt_regsets.AnySet(eRegsetMaskSVE); }
  bool IsPAuthEnabled() const { return m_opt_regsets.AnySet(eRegsetMaskPAuth); }
  bool IsMTEEnabled() const { return m_opt_regsets.AnySet(eRegsetMaskMTE); }

  bool IsSVEReg(unsigned reg) const;
  bool IsSVEZReg(unsigned reg) const;
+55 −36
Original line number Diff line number Diff line
@@ -17,6 +17,51 @@ class RegisterCommandsTestCase(TestBase):
        self.assertEqual(reg_value.GetByteSize(), expected,
                         'Verify "%s" == %i' % (name, expected))

    def check_sve_regs_read(self, z_reg_size):
        p_reg_size = int(z_reg_size / 8)

        for i in range(32):
            z_regs_value = '{' + \
                ' '.join('0x{:02x}'.format(i + 1)
                         for _ in range(z_reg_size)) + '}'
            self.expect("register read " + 'z%i' %
                        (i), substrs=[z_regs_value])

        p_value_bytes = ['0xff', '0x55', '0x11', '0x01', '0x00']
        for i in range(16):
            p_regs_value = '{' + \
                ' '.join(p_value_bytes[i % 5] for _ in range(p_reg_size)) + '}'
            self.expect("register read " + 'p%i' % (i), substrs=[p_regs_value])

        self.expect("register read ffr", substrs=[p_regs_value])

    def check_sve_regs_read_after_write(self, z_reg_size):
        p_reg_size = int(z_reg_size / 8)

        z_regs_value = '{' + \
            ' '.join(('0x9d' for _ in range(z_reg_size))) + '}'

        p_regs_value = '{' + \
            ' '.join(('0xee' for _ in range(p_reg_size))) + '}'

        for i in range(32):
            self.runCmd('register write ' + 'z%i' %
                        (i) + " '" + z_regs_value + "'")

        for i in range(32):
            self.expect("register read " + 'z%i' % (i), substrs=[z_regs_value])

        for i in range(16):
            self.runCmd('register write ' + 'p%i' %
                        (i) + " '" + p_regs_value + "'")

        for i in range(16):
            self.expect("register read " + 'p%i' % (i), substrs=[p_regs_value])

        self.runCmd('register write ' + 'ffr ' + "'" + p_regs_value + "'")

        self.expect("register read " + 'ffr', substrs=[p_regs_value])

    mydir = TestBase.compute_mydir(__file__)

    @no_debug_info_test
@@ -117,43 +162,17 @@ class RegisterCommandsTestCase(TestBase):

        z_reg_size = vg_reg_value * 8

        p_reg_size = int(z_reg_size / 8)

        for i in range(32):
            z_regs_value = '{' + \
                ' '.join('0x{:02x}'.format(i + 1)
                         for _ in range(z_reg_size)) + '}'
            self.expect("register read " + 'z%i' %
                        (i), substrs=[z_regs_value])
        self.check_sve_regs_read(z_reg_size)

        p_value_bytes = ['0xff', '0x55', '0x11', '0x01', '0x00']
        for i in range(16):
            p_regs_value = '{' + \
                ' '.join(p_value_bytes[i % 5] for _ in range(p_reg_size)) + '}'
            self.expect("register read " + 'p%i' % (i), substrs=[p_regs_value])
        # Evaluate simple expression and print function expr_eval_func address.
        self.expect("p expr_eval_func", substrs=["= 0x"])

        self.expect("register read ffr", substrs=[p_regs_value])
        # Evaluate expression call function expr_eval_func.
        self.expect_expr("expr_eval_func()",
                         result_type="int", result_value="1")

        z_regs_value = '{' + \
            ' '.join(('0x9d' for _ in range(z_reg_size))) + '}'
        # We called a jitted function above which must not have changed SVE
        # vector length or register values.
        self.check_sve_regs_read(z_reg_size)

        p_regs_value = '{' + \
            ' '.join(('0xee' for _ in range(p_reg_size))) + '}'

        for i in range(32):
            self.runCmd('register write ' + 'z%i' %
                        (i) + " '" + z_regs_value + "'")

        for i in range(32):
            self.expect("register read " + 'z%i' % (i), substrs=[z_regs_value])

        for i in range(16):
            self.runCmd('register write ' + 'p%i' %
                        (i) + " '" + p_regs_value + "'")

        for i in range(16):
            self.expect("register read " + 'p%i' % (i), substrs=[p_regs_value])

        self.runCmd('register write ' + 'ffr ' + "'" + p_regs_value + "'")

        self.expect("register read " + 'ffr', substrs=[p_regs_value])
        self.check_sve_regs_read_after_write(z_reg_size)
+18 −1
Original line number Diff line number Diff line
int main() {
#include <sys/prctl.h>

void write_sve_regs() {
  asm volatile("setffr\n\t");
  asm volatile("ptrue p0.b\n\t");
  asm volatile("ptrue p1.h\n\t");
@@ -49,5 +51,20 @@ int main() {
  asm volatile("cpy  z29.b, p5/z, #30\n\t");
  asm volatile("cpy  z30.b, p10/z, #31\n\t");
  asm volatile("cpy  z31.b, p15/z, #32\n\t");
}

// This function will be called using jitted expression call. We change vector
// length and write SVE registers. Our program context should restore to
// orignal vector length and register values after expression evaluation.
int expr_eval_func() {
  prctl(PR_SVE_SET_VL, 8 * 2);
  write_sve_regs();
  prctl(PR_SVE_SET_VL, 8 * 4);
  write_sve_regs();
  return 1;
}

int main() {
  write_sve_regs();
  return 0; // Set a break point here.
}