package internal import ( "bytes" "encoding/json" "errors" "io" "sync" "github.com/pion/webrtc/v2" ) type DataChannel struct { dataChannel *webrtc.DataChannel rawDataLock sync.RWMutex rawDataWriter *io.PipeWriter rawDataReader *io.PipeReader rawDataRest uint64 bufferWaitChannel chan interface{} closeWaitChannel chan interface{} onClose func() onOpen func() onSendaroundMessage func(*Message) onError func(error) } func NewDataChannel(dataChannel *webrtc.DataChannel) *DataChannel { dc := &DataChannel{ dataChannel: dataChannel, } dc.init() return dc } func (conn *DataChannel) OnClose(f func()) { conn.onClose = f } func (conn *DataChannel) OnOpen(f func()) { conn.onOpen = f } func (conn *DataChannel) OnSendaroundMessage(f func(*Message)) { conn.onSendaroundMessage = f } func (conn *DataChannel) OnError(f func(error)) { conn.onError = f } func (conn *DataChannel) Close() { // already closed? if conn.dataChannel.ReadyState() != webrtc.DataChannelStateClosing && conn.dataChannel.ReadyState() != webrtc.DataChannelStateClosed { conn.AbortRawDataTransmission(io.EOF) conn.dataChannel.Close() } if conn.dataChannel.ReadyState() != webrtc.DataChannelStateClosed { <-conn.closeWaitChannel } } func (conn *DataChannel) triggerOpen() { if conn.onOpen != nil { conn.onOpen() } } func (conn *DataChannel) triggerClose() { if conn.onClose != nil { conn.onClose() } } func (conn *DataChannel) triggerSendaroundMessage(msg *Message) { if conn.onSendaroundMessage != nil { conn.onSendaroundMessage(msg) } } func (conn *DataChannel) triggerError(err error) { if conn.onError != nil { conn.onError(err) } conn.dataChannel.Close() } func (conn *DataChannel) init() { conn.bufferWaitChannel = make(chan interface{}, 1) conn.closeWaitChannel = make(chan interface{}, 1) conn.dataChannel.OnOpen(func() { conn.triggerOpen() }) conn.dataChannel.SetBufferedAmountLowThreshold(1 * 1024 * 1024) conn.dataChannel.OnBufferedAmountLow(func() { select { case conn.bufferWaitChannel <- nil: default: } }) conn.dataChannel.OnMessage(func(rtcMessage webrtc.DataChannelMessage) { if !rtcMessage.IsString { conn.rawDataLock.RLock() defer conn.rawDataLock.RUnlock() if conn.rawDataRest == 0 { conn.triggerError(ErrUnexpectedRawDataTransmission) return } rawDataSize := uint64(len(rtcMessage.Data)) if conn.rawDataRest < rawDataSize { conn.triggerError(ErrUnexpectedOverlongRawData) return } _, err := conn.rawDataWriter.Write(rtcMessage.Data) if err != nil { conn.triggerError(err) } conn.rawDataRest -= rawDataSize if conn.rawDataRest == 0 { conn.rawDataWriter.CloseWithError(nil) conn.rawDataWriter = nil } return } m := new(Message) if err := json.NewDecoder(bytes.NewReader(rtcMessage.Data)).Decode(m); err != nil { conn.triggerError(ErrProtocolViolation) return } conn.triggerSendaroundMessage(m) }) conn.dataChannel.OnClose(func() { conn.closeWaitChannel <- nil conn.triggerClose() }) } func (conn *DataChannel) AbortRawDataTransmission(err error) { conn.rawDataLock.Lock() defer conn.rawDataLock.Unlock() if conn.rawDataRest > 0 { conn.rawDataRest = 0 } if conn.rawDataWriter != nil { conn.rawDataWriter.CloseWithError(err) conn.rawDataWriter = nil } if conn.rawDataReader != nil { conn.rawDataReader = nil } } func (conn *DataChannel) ExpectRawData(expectedLength uint64) { conn.rawDataLock.Lock() defer conn.rawDataLock.Unlock() conn.rawDataReader, conn.rawDataWriter = io.Pipe() conn.rawDataRest = expectedLength } func (conn *DataChannel) SendMessage(msg *Message) (err error) { if msg == nil { err = errors.New("msg must not be nil") return } buf := new(bytes.Buffer) if err = json.NewEncoder(buf).Encode(msg); err != nil { return } err = conn.dataChannel.SendText(buf.String()) if err != nil { panic(err) } return } func (conn *DataChannel) Read(p []byte) (n int, err error) { if conn.rawDataReader == nil { err = errors.New("Unexpected raw data read") return } n, err = conn.rawDataReader.Read(p) return } func (conn *DataChannel) Write(p []byte) (n int, err error) { err = conn.dataChannel.Send(p) if err == nil { n = len(p) } // Do not queue too much to avoid memory flood, wait for buffer to be empty enough if conn.dataChannel.BufferedAmount() > conn.dataChannel.BufferedAmountLowThreshold()*10 { <-conn.bufferWaitChannel } return }