223 lines
4.6 KiB
Go
223 lines
4.6 KiB
Go
|
package internal
|
||
|
|
||
|
import (
|
||
|
"bufio"
|
||
|
"bytes"
|
||
|
"encoding/json"
|
||
|
"errors"
|
||
|
"io"
|
||
|
"sync"
|
||
|
|
||
|
"github.com/pion/webrtc/v2"
|
||
|
)
|
||
|
|
||
|
type DataChannel struct {
|
||
|
dataChannel *webrtc.DataChannel
|
||
|
|
||
|
rawDataLock sync.RWMutex
|
||
|
rawDataWriter *io.PipeWriter
|
||
|
rawDataReaderBuffered *bufio.Reader
|
||
|
rawDataReader *io.PipeReader
|
||
|
rawDataRest uint64
|
||
|
|
||
|
bufferWaitChannel chan interface{}
|
||
|
closeWaitChannel chan interface{}
|
||
|
|
||
|
onClose func()
|
||
|
onOpen func()
|
||
|
onSendaroundMessage func(*Message)
|
||
|
onError func(error)
|
||
|
}
|
||
|
|
||
|
func NewDataChannel(dataChannel *webrtc.DataChannel) *DataChannel {
|
||
|
dc := &DataChannel{
|
||
|
dataChannel: dataChannel,
|
||
|
}
|
||
|
dc.init()
|
||
|
return dc
|
||
|
}
|
||
|
|
||
|
func (conn *DataChannel) OnClose(f func()) {
|
||
|
conn.onClose = f
|
||
|
}
|
||
|
|
||
|
func (conn *DataChannel) OnOpen(f func()) {
|
||
|
conn.onOpen = f
|
||
|
}
|
||
|
|
||
|
func (conn *DataChannel) OnSendaroundMessage(f func(*Message)) {
|
||
|
conn.onSendaroundMessage = f
|
||
|
}
|
||
|
|
||
|
func (conn *DataChannel) OnError(f func(error)) {
|
||
|
conn.onError = f
|
||
|
}
|
||
|
|
||
|
func (conn *DataChannel) Close() {
|
||
|
// already closed?
|
||
|
if conn.dataChannel.ReadyState() == webrtc.DataChannelStateClosing ||
|
||
|
conn.dataChannel.ReadyState() == webrtc.DataChannelStateClosed {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
conn.AbortRawDataTransmission(io.EOF)
|
||
|
conn.dataChannel.Close()
|
||
|
<-conn.closeWaitChannel
|
||
|
}
|
||
|
|
||
|
func (conn *DataChannel) triggerOpen() {
|
||
|
if conn.onOpen != nil {
|
||
|
conn.onOpen()
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (conn *DataChannel) triggerClose() {
|
||
|
if conn.onClose != nil {
|
||
|
conn.onClose()
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (conn *DataChannel) triggerSendaroundMessage(msg *Message) {
|
||
|
if conn.onSendaroundMessage != nil {
|
||
|
conn.onSendaroundMessage(msg)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (conn *DataChannel) triggerError(err error) {
|
||
|
if conn.onError != nil {
|
||
|
conn.onError(err)
|
||
|
}
|
||
|
conn.dataChannel.Close()
|
||
|
}
|
||
|
|
||
|
func (conn *DataChannel) init() {
|
||
|
conn.bufferWaitChannel = make(chan interface{}, 1)
|
||
|
conn.closeWaitChannel = make(chan interface{}, 1)
|
||
|
|
||
|
conn.dataChannel.OnOpen(func() {
|
||
|
conn.triggerOpen()
|
||
|
})
|
||
|
|
||
|
conn.dataChannel.SetBufferedAmountLowThreshold(1 * 1024 * 1024)
|
||
|
conn.dataChannel.OnBufferedAmountLow(func() {
|
||
|
select {
|
||
|
case conn.bufferWaitChannel <- nil:
|
||
|
default:
|
||
|
}
|
||
|
})
|
||
|
|
||
|
conn.dataChannel.OnMessage(func(rtcMessage webrtc.DataChannelMessage) {
|
||
|
if !rtcMessage.IsString {
|
||
|
conn.rawDataLock.RLock()
|
||
|
defer conn.rawDataLock.RUnlock()
|
||
|
|
||
|
if conn.rawDataRest == 0 {
|
||
|
conn.triggerError(ErrUnexpectedRawDataTransmission)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
rawDataSize := uint64(len(rtcMessage.Data))
|
||
|
if conn.rawDataRest < rawDataSize {
|
||
|
conn.triggerError(ErrUnexpectedOverlongRawData)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
_, err := conn.rawDataWriter.Write(rtcMessage.Data)
|
||
|
if err != nil {
|
||
|
conn.triggerError(err)
|
||
|
}
|
||
|
|
||
|
conn.rawDataRest -= rawDataSize
|
||
|
if conn.rawDataRest == 0 {
|
||
|
conn.rawDataWriter.CloseWithError(nil)
|
||
|
conn.rawDataWriter = nil
|
||
|
}
|
||
|
return
|
||
|
}
|
||
|
|
||
|
m := new(Message)
|
||
|
if err := json.NewDecoder(bytes.NewReader(rtcMessage.Data)).Decode(m); err != nil {
|
||
|
conn.triggerError(ErrProtocolViolation)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
conn.triggerSendaroundMessage(m)
|
||
|
})
|
||
|
|
||
|
conn.dataChannel.OnClose(func() {
|
||
|
select {
|
||
|
case conn.closeWaitChannel <- nil:
|
||
|
default:
|
||
|
}
|
||
|
|
||
|
conn.triggerClose()
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func (conn *DataChannel) AbortRawDataTransmission(err error) {
|
||
|
conn.rawDataLock.Lock()
|
||
|
defer conn.rawDataLock.Unlock()
|
||
|
|
||
|
if conn.rawDataRest > 0 {
|
||
|
conn.rawDataRest = 0
|
||
|
}
|
||
|
|
||
|
if conn.rawDataWriter != nil {
|
||
|
conn.rawDataWriter.CloseWithError(err)
|
||
|
conn.rawDataWriter = nil
|
||
|
}
|
||
|
|
||
|
if conn.rawDataReader != nil {
|
||
|
conn.rawDataReader = nil
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (conn *DataChannel) ExpectRawData(expectedLength uint64) {
|
||
|
conn.rawDataLock.Lock()
|
||
|
defer conn.rawDataLock.Unlock()
|
||
|
|
||
|
conn.rawDataReader, conn.rawDataWriter = io.Pipe()
|
||
|
conn.rawDataReaderBuffered = bufio.NewReaderSize(conn.rawDataReader, 16*1024*1024)
|
||
|
conn.rawDataRest = expectedLength
|
||
|
}
|
||
|
|
||
|
func (conn *DataChannel) SendMessage(msg *Message) (err error) {
|
||
|
if msg == nil {
|
||
|
err = errors.New("msg must not be nil")
|
||
|
return
|
||
|
}
|
||
|
|
||
|
buf := new(bytes.Buffer)
|
||
|
if err = json.NewEncoder(buf).Encode(msg); err != nil {
|
||
|
return
|
||
|
}
|
||
|
err = conn.dataChannel.SendText(buf.String())
|
||
|
if err != nil {
|
||
|
panic(err)
|
||
|
}
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func (conn *DataChannel) Read(p []byte) (n int, err error) {
|
||
|
if conn.rawDataReaderBuffered == nil {
|
||
|
err = errors.New("Unexpected raw data read")
|
||
|
return
|
||
|
}
|
||
|
|
||
|
n, err = conn.rawDataReaderBuffered.Read(p)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func (conn *DataChannel) Write(p []byte) (n int, err error) {
|
||
|
err = conn.dataChannel.Send(p)
|
||
|
if err == nil {
|
||
|
n = len(p)
|
||
|
}
|
||
|
|
||
|
// Do not queue too much to avoid memory flood, wait for buffer to be empty enough
|
||
|
if conn.dataChannel.BufferedAmount() > conn.dataChannel.BufferedAmountLowThreshold()*10 {
|
||
|
<-conn.bufferWaitChannel
|
||
|
}
|
||
|
return
|
||
|
}
|