Commit b1413673 authored by Ferreira Da Silva, Rafael's avatar Ferreira Da Silva, Rafael
Browse files

improving documentation

parent 9f47ffa5
Loading
Loading
Loading
Loading

train.py

100755 → 100644
+47 −2
Original line number Diff line number Diff line
import zmq
import copy
import os, yaml, argparse
import warnings
@@ -16,6 +17,7 @@ import torch.distributed as dist
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import itertools
import time
import json

warnings.filterwarnings('ignore')
from pf import TimeSeriesDataSet, Baseline, TemporalFusionTransformer
@@ -243,9 +245,52 @@ if __name__ == "__main__":
    print0('-- encoder_length: {}'.format(encoder_length))
    print0('-- prediction_length: {}'.format(prediction_length))

    ## zmq
    context = zmq.Context()
    # subscription socket where we'll receive data
    subsocket = context.socket(zmq.SUB)
    subsocket.setsockopt_string(zmq.SUBSCRIBE, "metadata")
    subsocket.setsockopt_string(zmq.SUBSCRIBE, "data")
    subsocket.setsockopt_string(zmq.SUBSCRIBE, "END")
    subsocket.connect("tcp://login11.frontier.olcf.ornl.gov:5555")

    # let the publisher know we are established
    sigsocket = context.socket(zmq.REQ)
    sigsocket.connect("tcp://login11.frontier.olcf.ornl.gov:5556")
    sigsocket.send(b'')
    sigsocket.recv()

    print("Synchronized; ready to receive data")
    def receive_byte_array(socket):
        chunks = []
        while True:
            try:
                topic, chunk = socket.recv_multipart()
                print(f"Received message with topic: {topic.decode()}")  # Debugging print
                if topic.decode() == "END":
                    break
                elif topic.decode() == "data":
                    chunks.append(chunk)
            except Exception as e:
                print(f"Error receiving chunk: {e}")
                break
        return b''.join(chunks)

    # Receive metadata
    topic, metadata_str = subsocket.recv_multipart()
    md = json.loads(metadata_str)
    print(f"Metadata received: {md}")

    byte_array = receive_byte_array(subsocket)

    ary = np.frombuffer(byte_array, dtype=np.dtype(md['dtype']))
    ary = ary.reshape(md['shape'])
    print(f'Received {ary.size} elements of type {ary.dtype}')

    ## Preprocess training sequences
    data_path = config['data_path']
    data_np = np.load(data_path)  # lev1_data
    #data_path = config['data_path']
    #data_np = np.load(data_path)  # lev1_data
    data_np = ary

    ## Prepare the lev1, lev2 training data, for lev2, we have different strategies based on the mapping_mode
    if config['mapping_mode'] == 0:  # single rank for level 2