Commit 98856b22 authored by Lei Zhang's avatar Lei Zhang
Browse files

[mlir][spirv] Update SPIR-V enums and ops with availability spec

This commit updates gen_spirv_dialect.py to query the grammar and
generate availability spec for various enum attribute definitions
and all defined ops.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D72095
parent 838f53ed
Loading
Loading
Loading
Loading
+1619 −308

File changed.

Preview size limit exceeded, changes collapsed.

+77 −0
Original line number Diff line number Diff line
@@ -88,6 +88,13 @@ def SPV_BitCountOp : SPV_BitUnaryOp<"BitCount", []> {
    %3 = spv.BitCount %1: vector<4xi32>
    ```
  }];

  let availability = [
    MinVersion<SPV_V_1_0>,
    MaxVersion<SPV_V_1_5>,
    Extension<[]>,
    Capability<[]>
  ];
}

// -----
@@ -139,6 +146,13 @@ def SPV_BitFieldInsertOp : SPV_Op<"BitFieldInsert", [NoSideEffect]> {
    ```
  }];

  let availability = [
    MinVersion<SPV_V_1_0>,
    MaxVersion<SPV_V_1_5>,
    Extension<[]>,
    Capability<[SPV_C_Shader]>
  ];

  let arguments = (ins
    SPV_ScalarOrVectorOf<SPV_Integer>:$base,
    SPV_ScalarOrVectorOf<SPV_Integer>:$insert,
@@ -196,6 +210,13 @@ def SPV_BitFieldSExtractOp : SPV_BitFieldExtractOp<"BitFieldSExtract", []> {
    %0 = spv.BitFieldSExtract %base, %offset, %count : vector<3xi32>, i8, i8
    ```
  }];

  let availability = [
    MinVersion<SPV_V_1_0>,
    MaxVersion<SPV_V_1_5>,
    Extension<[]>,
    Capability<[SPV_C_Shader]>
  ];
}

// -----
@@ -225,6 +246,13 @@ def SPV_BitFieldUExtractOp : SPV_BitFieldExtractOp<"BitFieldUExtract", []> {
    %0 = spv.BitFieldUExtract %base, %offset, %count : vector<3xi32>, i8, i8
    ```
  }];

  let availability = [
    MinVersion<SPV_V_1_0>,
    MaxVersion<SPV_V_1_5>,
    Extension<[]>,
    Capability<[SPV_C_Shader]>
  ];
}

// -----
@@ -258,6 +286,13 @@ def SPV_BitReverseOp : SPV_BitUnaryOp<"BitReverse", []> {
    %3 = spv.BitReverse %1 : vector<4xi32>
    ```
  }];

  let availability = [
    MinVersion<SPV_V_1_0>,
    MaxVersion<SPV_V_1_5>,
    Extension<[]>,
    Capability<[SPV_C_Shader]>
  ];
}

// -----
@@ -292,6 +327,13 @@ def SPV_BitwiseAndOp : SPV_BitBinaryOp<"BitwiseAnd", [Commutative]> {
    %2 = spv.BitwiseAnd %0, %1 : vector<4xi32>
    ```
  }];

  let availability = [
    MinVersion<SPV_V_1_0>,
    MaxVersion<SPV_V_1_5>,
    Extension<[]>,
    Capability<[]>
  ];
}

// -----
@@ -326,6 +368,13 @@ def SPV_BitwiseOrOp : SPV_BitBinaryOp<"BitwiseOr", [Commutative]> {
    %2 = spv.BitwiseOr %0, %1 : vector<4xi32>
    ```
  }];

  let availability = [
    MinVersion<SPV_V_1_0>,
    MaxVersion<SPV_V_1_5>,
    Extension<[]>,
    Capability<[]>
  ];
}

// -----
@@ -360,6 +409,13 @@ def SPV_BitwiseXorOp : SPV_BitBinaryOp<"BitwiseXor", [Commutative]> {
    %2 = spv.BitwiseXor %0, %1 : vector<4xi32>
    ```
  }];

  let availability = [
    MinVersion<SPV_V_1_0>,
    MaxVersion<SPV_V_1_5>,
    Extension<[]>,
    Capability<[]>
  ];
}

// -----
@@ -404,6 +460,13 @@ def SPV_ShiftLeftLogicalOp : SPV_ShiftOp<"ShiftLeftLogical", []> {
    %5 = spv.ShiftLeftLogical %3, %4 : vector<3xi32>, vector<3xi16>
    ```
  }];

  let availability = [
    MinVersion<SPV_V_1_0>,
    MaxVersion<SPV_V_1_5>,
    Extension<[]>,
    Capability<[]>
  ];
}

// -----
@@ -445,6 +508,13 @@ def SPV_ShiftRightArithmeticOp : SPV_ShiftOp<"ShiftRightArithmetic", []> {
    %5 = spv.ShiftRightArithmetic %3, %4 : vector<3xi32>, vector<3xi16>
    ```
  }];

  let availability = [
    MinVersion<SPV_V_1_0>,
    MaxVersion<SPV_V_1_5>,
    Extension<[]>,
    Capability<[]>
  ];
}

// -----
@@ -487,6 +557,13 @@ def SPV_ShiftRightLogicalOp : SPV_ShiftOp<"ShiftRightLogical", []> {
    %5 = spv.ShiftRightLogical %3, %4 : vector<3xi32>, vector<3xi16>
    ```
  }];

  let availability = [
    MinVersion<SPV_V_1_0>,
    MaxVersion<SPV_V_1_5>,
    Extension<[]>,
    Capability<[]>
  ];
}

// -----
+7 −0
Original line number Diff line number Diff line
@@ -49,6 +49,13 @@ def SPV_SubgroupBallotKHROp : SPV_Op<"SubgroupBallotKHR", []> {
    ```
  }];

  let availability = [
    MinVersion<SPV_V_1_0>,
    MaxVersion<SPV_V_1_5>,
    Extension<[SPV_KHR_shader_ballot]>,
    Capability<[SPV_C_SubgroupBallotKHR]>
  ];

  let arguments = (ins
    SPV_Bool:$predicate
  );
+1 −1
Original line number Diff line number Diff line
@@ -35,7 +35,7 @@ func @module_logical_glsl450() {
  // CHECK: spv.module min version: V_1_0
  // CHECK: spv.module max version: V_1_5
  // CHECK: spv.module extensions: [ ]
  // CHECK: spv.module capabilities: [ ]
  // CHECK: spv.module capabilities: [ [Shader] ]
  spv.module "Logical" "GLSL450" { }
  return
}
+149 −21
Original line number Diff line number Diff line
@@ -111,9 +111,11 @@ def uniquify_enum_cases(lst):
   - A list with all duplicates removed. The elements are sorted according to
     value and, for each value, uniqued according to symbol.
     original list,
   - A map from deduplicated cases to the uniqued case.
  """
  cases = lst
  uniqued_cases = []
  duplicated_cases = {}

  # First sort according to the value
  cases.sort(key=lambda x: x[1])
@@ -125,14 +127,110 @@ def uniquify_enum_cases(lst):
    # Keep the "smallest" case, which is typically the symbol without extension
    # suffix. But we have special cases that we want to fix.
    case = sorted_group[0]
    for i in range(1, len(sorted_group)):
      duplicated_cases[sorted_group[i][0]] = case[0]
    if case[0] == 'HlslSemanticGOOGLE':
      assert len(sorted_group) == 2, 'unexpected new variant for HlslSemantic'
      case = sorted_group[1]
      duplicated_cases[sorted_group[0][0]] = case[0]
    uniqued_cases.append(case)

  return uniqued_cases
  return uniqued_cases, duplicated_cases


def gen_operand_kind_enum_attr(operand_kind):
def get_capability_mapping(operand_kinds):
  """Returns the capability mapping from duplicated cases to their canonicalized

  case.

  Arguments:
    - operand_kinds: all operand kinds' grammar spec

  Returns:
    - A map mapping from duplicated capability symbols to the canonicalized
      symbol chosen for SPIRVBase.td.
  """
  # Find the operand kind for capability
  cap_kind = {}
  for kind in operand_kinds:
    if kind['kind'] == 'Capability':
      cap_kind = kind

  kind_cases = [
      (case['enumerant'], case['value']) for case in cap_kind['enumerants']
  ]
  _, capability_mapping = uniquify_enum_cases(kind_cases)

  return capability_mapping


def get_availability_spec(enum_case, capability_mapping, for_op):
  """Returns the availability specification string for the given enum case.

  Arguments:
    - enum_case: the enum case to generate availability spec for. It may contain
      'version', 'lastVersion', 'extensions', or 'capabilities'.
    - capability_mapping: mapping from duplicated capability symbols to the
      canonicalized symbol chosen for SPIRVBase.td.
    - for_op: bool value indicating whether this is the availability spec for an
      op itself.

  Returns:
    - A `let availability = [...];` string if with availability spec or
      empty string if without availability spec
  """
  min_version = enum_case.get('version', '')
  if min_version == 'None':
    min_version = ''
  elif min_version:
    min_version = 'MinVersion<SPV_V_{}>'.format(min_version.replace('.', '_'))
  # TODO(antiagainst): delete this once ODS can support dialect-specific content
  # and we can use omission to mean no requirements.
  if for_op and not min_version:
    min_version = 'MinVersion<SPV_V_1_0>'

  max_version = enum_case.get('lastVersion', '')
  if max_version:
    max_version = 'MaxVersion<SPV_V_{}>'.format(max_version.replace('.', '_'))
  # TODO(antiagainst): delete this once ODS can support dialect-specific content
  # and we can use omission to mean no requirements.
  if for_op and not max_version:
    max_version = 'MaxVersion<SPV_V_1_5>'

  exts = enum_case.get('extensions', [])
  if exts:
    exts = 'Extension<[{}]>'.format(', '.join(sorted(set(exts))))
  # TODO(antiagainst): delete this once ODS can support dialect-specific content
  # and we can use omission to mean no requirements.
  if for_op and not exts:
    exts = 'Extension<[]>'

  caps = enum_case.get('capabilities', [])
  if caps:
    canonicalized_caps = []
    for c in caps:
      if c in capability_mapping:
        canonicalized_caps.append(capability_mapping[c])
      else:
        canonicalized_caps.append(c)
    caps = 'Capability<[{}]>'.format(', '.join(
        ['SPV_C_{}'.format(c) for c in sorted(set(canonicalized_caps))]))
  # TODO(antiagainst): delete this once ODS can support dialect-specific content
  # and we can use omission to mean no requirements.
  if for_op and not caps:
    caps = 'Capability<[]>'

  avail = ''
  if min_version or max_version or caps or exts:
    joined_spec = ',\n    '.join(
        [e for e in [min_version, max_version, exts, caps] if e])
    avail = '{} availability = [\n    {}\n  ];'.format(
        'let' if for_op else 'list<Availability>', joined_spec)

  return avail


def gen_operand_kind_enum_attr(operand_kind, capability_mapping):
  """Generates the TableGen EnumAttr definition for the given operand kind.

  Returns:
@@ -155,24 +253,37 @@ def gen_operand_kind_enum_attr(operand_kind):
  is_bit_enum = operand_kind['category'] == 'BitEnum'
  kind_category = 'Bit' if is_bit_enum else 'I32'
  kind_acronym = ''.join([c for c in kind_name if c >= 'A' and c <= 'Z'])

  name_to_case_dict = {}
  for case in operand_kind['enumerants']:
    name_to_case_dict[case['enumerant']] = case

  kind_cases = [(case['enumerant'], case['value'])
                for case in operand_kind['enumerants']]
  kind_cases = uniquify_enum_cases(kind_cases)
  kind_cases, _ = uniquify_enum_cases(kind_cases)
  max_len = max([len(symbol) for (symbol, _) in kind_cases])

  # Generate the definition for each enum case
  fmt_str = 'def SPV_{acronym}_{case} {colon:>{offset}} '\
            '{category}EnumAttrCase<"{symbol}", {value}>;'
  case_defs = [
      fmt_str.format(
            '{category}EnumAttrCase<"{symbol}", {value}>{avail}'
  case_defs = []
  for case in kind_cases:
    if kind_name == 'Capability':
      avail = ''
    else:
      avail = get_availability_spec(name_to_case_dict[case[0]],
                                    capability_mapping,
                                    False)
    case_def = fmt_str.format(
        category=kind_category,
        acronym=kind_acronym,
        case=case[0],
        symbol=get_case_symbol(kind_name, case[0]),
        value=case[1],
        avail=' {{\n  {}\n}}'.format(avail) if avail else ';',
        colon=':',
          offset=(max_len + 1 - len(case[0]))) for case in kind_cases
  ]
        offset=(max_len + 1 - len(case[0])))
    case_defs.append(case_def)
  case_defs = '\n'.join(case_defs)

  # Generate the list of enum case names
@@ -287,9 +398,14 @@ def update_td_enum_attrs(path, operand_kinds, filter_list):
      k[8:-4] for k in re.findall('def SPV_\w+Attr', content[1])]
  filter_list.extend(existing_kinds)

  capability_mapping = get_capability_mapping(operand_kinds)

  # Generate definitions for all enums in filter list
  defs = [gen_operand_kind_enum_attr(kind)
          for kind in operand_kinds if kind['kind'] in filter_list]
  defs = [
      gen_operand_kind_enum_attr(kind, capability_mapping)
      for kind in operand_kinds
      if kind['kind'] in filter_list
  ]
  # Sort alphabetically according to enum name
  defs.sort(key=lambda enum : enum[0])
  # Only keep the definitions from now on
@@ -387,7 +503,7 @@ def get_description(text, assembly):
      text=text, assembly=assembly)


def get_op_definition(instruction, doc, existing_info):
def get_op_definition(instruction, doc, existing_info, capability_mapping):
  """Generates the TableGen op definition for the given SPIR-V instruction.

  Arguments:
@@ -395,6 +511,8 @@ def get_op_definition(instruction, doc, existing_info):
    - doc: the instruction's SPIR-V HTML doc
    - existing_info: a dict containing potential manually specified sections for
      this instruction
    - capability_mapping: mapping from duplicated capability symbols to the
                   canonicalized symbol chosen for SPIRVBase.td

  Returns:
    - A string containing the TableGen op definition
@@ -402,7 +520,7 @@ def get_op_definition(instruction, doc, existing_info):
  fmt_str = ('def SPV_{opname}Op : '
             'SPV_{inst_category}<"{opname}"{category_args}[{traits}]> '
             '{{\n  let summary = {summary};\n\n  let description = '
             '[{{\n{description}}}];\n')
             '[{{\n{description}}}];{availability}\n')
  inst_category = existing_info.get('inst_category', 'Op')
  if inst_category == 'Op':
    fmt_str +='\n  let arguments = (ins{args});\n\n'\
@@ -439,6 +557,11 @@ def get_op_definition(instruction, doc, existing_info):

  operands = instruction.get('operands', [])

  # Op availability
  avail = get_availability_spec(instruction, capability_mapping, True)
  if avail:
    avail = '\n\n  {0}'.format(avail)

  # Set op's result
  results = ''
  if len(operands) > 0 and operands[0]['kind'] == 'IdResultType':
@@ -478,6 +601,7 @@ def get_op_definition(instruction, doc, existing_info):
      traits=existing_info.get('traits', ''),
      summary=summary,
      description=description,
      availability=avail,
      args=arguments,
      results=results,
      extras=existing_info.get('extras', ''))
@@ -600,7 +724,7 @@ def extract_td_op_info(op_def):


def update_td_op_definitions(path, instructions, docs, filter_list,
                             inst_category):
                             inst_category, capability_mapping):
  """Updates SPIRVOps.td with newly generated op definition.

  Arguments:
@@ -608,6 +732,8 @@ def update_td_op_definitions(path, instructions, docs, filter_list,
    - instructions: SPIR-V JSON grammar for all instructions
    - docs: SPIR-V HTML doc for all instructions
    - filter_list: a list containing new opnames to include
    - capability_mapping: mapping from duplicated capability symbols to the
                   canonicalized symbol chosen for SPIRVBase.td.

  Returns:
    - A string containing all the TableGen op definitions
@@ -643,7 +769,8 @@ def update_td_op_definitions(path, instructions, docs, filter_list,
      op_defs.append(
          get_op_definition(
              instruction, docs[opname],
              op_info_dict.get(opname, {'inst_category': inst_category})))
              op_info_dict.get(opname, {'inst_category': inst_category}),
              capability_mapping))
    except StopIteration:
      # This is an op added by us; use the existing ODS definition.
      op_defs.append(name_op_map[opname])
@@ -722,8 +849,9 @@ if __name__ == '__main__':
  if args.new_inst is not None:
    assert args.op_td_path is not None
    docs = get_spirv_doc_from_html_spec()
    capability_mapping = get_capability_mapping(operand_kinds)
    update_td_op_definitions(args.op_td_path, instructions, docs, args.new_inst,
                             args.inst_category)
                             args.inst_category, capability_mapping)
    print('Done. Note that this script just generates a template; ', end='')
    print('please read the spec and update traits, arguments, and ', end='')
    print('results accordingly.')