Commit 32da96a8 authored by Ferreira Da Silva, Rafael's avatar Ferreira Da Silva, Rafael
Browse files

Update train.py

parent cc87600a
Loading
Loading
Loading
Loading
+7 −2
Original line number Diff line number Diff line
@@ -255,13 +255,14 @@ if __name__ == "__main__":
    subsocket.setsockopt_string(zmq.SUBSCRIBE, "END")
    subsocket.connect("tcp://REPLACE_HOST:5560")

    print("Synchronized; ready to receive data")
    print(f"Rank-{torch.distributed.get_rank()}: Synchronized; ready to receive data")
    start_time = datetime.now()
    def receive_byte_array(socket):
        chunks = []
        while True:
            try:
                topic, chunk = socket.recv_multipart()
                print(f"Received message with topic: {topic.decode()}")  # Debugging print
                # print(f"Received message with topic: {topic.decode()}")  # Debugging print
                if topic.decode() == "END":
                    break
                elif topic.decode() == "data":
@@ -277,11 +278,15 @@ if __name__ == "__main__":
    print(f"Metadata received: {md}")

    byte_array = receive_byte_array(subsocket)
    end_time = datetime.now()

    ary = np.frombuffer(byte_array, dtype=np.dtype(md['dtype']))
    ary = ary.reshape(md['shape'])
    print(f'Received {ary.size} elements of type {ary.dtype}')

    print(f"[Rank-{torch.distributed.get_rank()}] Time receiving data: {(end_time - start_time).total_seconds()} seconds")
    print(f"[Rank-{torch.distributed.get_rank()}] Transfer rate: {(ary.size * ary.itemsize) / 125000 / (end_time - start_time).total_seconds()} Mbps")

    ## Preprocess training sequences
    data_np = ary