Unverified Commit 8c3dbe42 authored by Björn Grüning's avatar Björn Grüning Committed by GitHub
Browse files

Merge pull request #20754 from nilchia/safetensors_dt

[25.0] Add safetensors datatype
parents f9032e47 9896e9b4
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -1173,6 +1173,7 @@
    <datatype extension="bcsp" type="galaxy.datatypes.binary:Binary" mimetype="application/octet-stream" display_in_upload="true" subclass="true" description="Binary format of k-mer hash table which is only compatible with Fairy"/>
    <!-- rdeval types -->
    <datatype extension="rd" type="galaxy.datatypes.binary:Binary" mimetype="application/octet-stream" display_in_upload="true" subclass="true" description="Rdeval read sketch"/>
    <datatype extension="safetensors" type="galaxy.datatypes.binary:Safetensors" mimetype="application/octet-stream" display_in_upload="true" description="A simple format for storing tensors safely (as opposed to pickle) and that is still fast (zero-copy)" description_url="https://huggingface.co/docs/safetensors/index"/>
  </registration>

  <sniffers>
+99 −0
Original line number Diff line number Diff line
@@ -4864,3 +4864,102 @@ class Hic(Binary):
        with open(dataset.get_file_name(), "rb") as handle:
            header_bytes = handle.read(8)
        dataset.metadata.version = struct.unpack("<i", header_bytes[4:8])[0]


@build_sniff_from_prefix
class Safetensors(Binary):
    """
    safetensors is a new simple format for storing tensors safely (as opposed to pickle) and that is still fast (zero-copy).
    It provides a secure way to store and load tensors without the security risks associated with pickle-based formats.
    Safetensors files consist of a JSON header followed by tensor data.
    more info at: https://github.com/huggingface/safetensors
    """

    file_ext = "safetensors"

    def sniff_prefix(self, file_prefix: FilePrefix) -> bool:
        """
        Determining if the file is in safetensors format
        >>> from galaxy.datatypes.sniff import get_test_fname
        >>> fname = get_test_fname('cellpose_model_safetensors.safetensors')
        >>> Safetensors().sniff(fname)
        True
        >>> fname = get_test_fname('test_charmm.vel')
        >>> Safetensors().sniff(fname)
        False
        """
        try:
            # Safetensors files start with an 8-byte little-endian integer
            # indicating the size of the JSON header
            if len(file_prefix.contents_header_bytes) < 8:
                return False

            header_size = int.from_bytes(file_prefix.contents_header_bytes[:8], "little")

            # Currently, there's a limit on the size of the header of 100MB to prevent parsing extremely large JSON headers
            # In practice, safetensors headers are typically just a few KB to MB
            # (containing tensor names, shapes, dtypes, and offsets - rarely exceeds 1-10MB even for large models)
            # But in theory it is possible to have 100 MB header
            # more info here: https://github.com/huggingface/safetensors?tab=readme-ov-file#benefits
            if header_size == 0 or header_size > 10**8:  # 100MB max for JSON header
                return False

            # Check if file is large enough to contain the full header
            if file_prefix.file_size < 8 + header_size:
                return False

            # CRITICAL: Check if header begins with '{' character (0x7B) as per safetensors spec
            # This is required by the format and helps distinguish from other binary formats
            # Only check 1 byte to avoid issues with malicious header_size values
            # more info here: https://github.com/huggingface/safetensors?tab=readme-ov-file#format
            if file_prefix.contents_header_bytes[8] != 0x7B:
                return False

            # Check if header ends with '}' character (0x7D) as per safetensors spec
            # This requires reading more data if header extends beyond the prefix
            header_end_pos = 8 + header_size - 1
            if header_end_pos < len(file_prefix.contents_header_bytes):
                # Header end is within the prefix
                if file_prefix.contents_header_bytes[header_end_pos] != 0x7D:
                    return False
            else:
                # Header extends beyond prefix, need to check from file
                with open(file_prefix.filename, "rb") as f:
                    f.seek(header_end_pos)
                    last_header_byte = f.read(1)
                    if len(last_header_byte) != 1 or last_header_byte[0] != 0x7D:
                        return False

            # Read the full header for JSON parsing
            if 8 + header_size <= len(file_prefix.contents_header_bytes):
                # Entire header is in the prefix
                header_bytes = file_prefix.contents_header_bytes[8 : 8 + header_size]
            else:
                # Need to read full header from file
                with open(file_prefix.filename, "rb") as f:
                    f.seek(8)
                    header_bytes = f.read(header_size)

            if len(header_bytes) != header_size:
                return False

            # Parse the validated JSON header
            header = json.loads(header_bytes.decode("utf-8"))
            # check if header is a dict
            if not isinstance(header, dict):
                return False
            # Basic validation: check if it looks like safetensors metadata
            # Safetensors headers should have entries with data_offsets
            has_valid_entries = False
            for key, value in header.items():
                if key == "__metadata__":  # Special metadata key
                    continue
                if isinstance(value, dict) and "data_offsets" in value:
                    has_valid_entries = True
                    break

            return has_valid_entries

        except Exception:
            # Any exception during parsing means it's not a valid safetensors file
            return False
+76 B

File added.

No diff preview for this file type.