codename-sendaround/internal/data_channel.go

223 lines
4.6 KiB
Go
Raw Normal View History

2019-07-11 12:00:45 +00:00
package internal
import (
"bufio"
"bytes"
"encoding/json"
"errors"
"io"
"sync"
"github.com/pion/webrtc/v2"
)
// DataChannel wraps a pion WebRTC data channel. Text frames carry
// JSON-encoded Messages; binary frames carry "raw data" that is exposed
// through Read after ExpectRawData has announced its length, and Write
// sends binary frames with simple back-pressure.
type DataChannel struct {
	// dataChannel is the wrapped WebRTC data channel.
	dataChannel *webrtc.DataChannel

	// rawDataLock guards the raw-data transfer state below. Note that Read
	// accesses rawDataReaderBuffered without taking it.
	rawDataLock sync.RWMutex
	// rawDataWriter is the write half of the pipe that incoming binary
	// frames are written to; set to nil once a transfer completes or is
	// aborted.
	rawDataWriter *io.PipeWriter
	// rawDataReaderBuffered wraps rawDataReader in a 16 MiB bufio.Reader and
	// is what Read consumes from.
	rawDataReaderBuffered *bufio.Reader
	// rawDataReader is the read half of the pipe created by ExpectRawData.
	rawDataReader *io.PipeReader
	// rawDataRest is the number of raw bytes still expected.
	rawDataRest uint64

	// bufferWaitChannel is signalled (non-blocking send, capacity 1) when the
	// send buffer drops below the low threshold; Write waits on it.
	bufferWaitChannel chan interface{}
	// closeWaitChannel is signalled (non-blocking send, capacity 1) from the
	// OnClose handler; Close waits on it.
	closeWaitChannel chan interface{}

	// User callbacks registered via the On* setters; each may be nil.
	onClose func()
	onOpen func()
	onSendaroundMessage func(*Message)
	onError func(error)
}
// NewDataChannel wraps dataChannel and installs this package's handlers on
// it. Callbacks registered afterwards via OnOpen, OnClose, etc. take effect
// immediately, because the handlers look them up at call time.
func NewDataChannel(dataChannel *webrtc.DataChannel) *DataChannel {
	wrapper := new(DataChannel)
	wrapper.dataChannel = dataChannel
	wrapper.init()
	return wrapper
}
// OnClose sets the callback run after the underlying channel has closed.
func (conn *DataChannel) OnClose(f func()) {
	conn.onClose = f
}

// OnOpen sets the callback run when the underlying channel opens.
func (conn *DataChannel) OnOpen(f func()) {
	conn.onOpen = f
}

// OnSendaroundMessage sets the callback that receives each decoded text
// message.
func (conn *DataChannel) OnSendaroundMessage(f func(*Message)) {
	conn.onSendaroundMessage = f
}

// OnError sets the callback notified of protocol or transfer errors. The
// underlying channel is closed right after this callback returns.
func (conn *DataChannel) OnError(f func(error)) {
	conn.onError = f
}
// Close aborts any raw data transfer in progress (pending readers observe
// io.EOF), closes the underlying WebRTC data channel and blocks until the
// OnClose handler has signalled. It is a no-op if the channel is already
// closing or closed.
func (conn *DataChannel) Close() {
	switch conn.dataChannel.ReadyState() {
	case webrtc.DataChannelStateClosing, webrtc.DataChannelStateClosed:
		return
	}
	conn.AbortRawDataTransmission(io.EOF)
	if err := conn.dataChannel.Close(); err != nil {
		// The close request failed, so the OnClose handler may never run;
		// waiting on closeWaitChannel here could block forever.
		return
	}
	<-conn.closeWaitChannel
}
// triggerOpen invokes the open callback, if one is registered.
func (conn *DataChannel) triggerOpen() {
	if conn.onOpen == nil {
		return
	}
	conn.onOpen()
}

// triggerClose invokes the close callback, if one is registered.
func (conn *DataChannel) triggerClose() {
	if conn.onClose == nil {
		return
	}
	conn.onClose()
}

// triggerSendaroundMessage hands msg to the message callback, if one is
// registered.
func (conn *DataChannel) triggerSendaroundMessage(msg *Message) {
	if conn.onSendaroundMessage == nil {
		return
	}
	conn.onSendaroundMessage(msg)
}

// triggerError reports err to the error callback (if any) and then closes
// the underlying data channel; the result of that close is not checked.
func (conn *DataChannel) triggerError(err error) {
	if handler := conn.onError; handler != nil {
		handler(err)
	}
	conn.dataChannel.Close()
}
// init allocates the signalling channels and registers handlers on the
// underlying WebRTC data channel:
//   - OnOpen forwards to the registered open callback;
//   - OnBufferedAmountLow (threshold 1 MiB) wakes a Write that is waiting
//     for the send buffer to drain;
//   - OnMessage routes binary frames into the raw-data pipe and decodes
//     text frames as JSON Messages;
//   - OnClose wakes a pending Close and forwards to the close callback.
func (conn *DataChannel) init() {
	conn.bufferWaitChannel = make(chan interface{}, 1)
	conn.closeWaitChannel = make(chan interface{}, 1)

	conn.dataChannel.OnOpen(func() {
		conn.triggerOpen()
	})

	conn.dataChannel.SetBufferedAmountLowThreshold(1 * 1024 * 1024)
	conn.dataChannel.OnBufferedAmountLow(func() {
		// Non-blocking: a single pending wake-up is enough for Write.
		select {
		case conn.bufferWaitChannel <- nil:
		default:
		}
	})

	conn.dataChannel.OnMessage(func(rtcMessage webrtc.DataChannelMessage) {
		if !rtcMessage.IsString {
			conn.handleRawData(rtcMessage.Data)
			return
		}
		conn.handleTextMessage(rtcMessage.Data)
	})

	conn.dataChannel.OnClose(func() {
		// Non-blocking: Close may not be waiting (e.g. remote close).
		select {
		case conn.closeWaitChannel <- nil:
		default:
		}
		conn.triggerClose()
	})
}

// handleRawData feeds one binary frame into the pipe set up by ExpectRawData
// and closes the pipe (EOF for the reader) once the announced length has
// been received. Unexpected or overlong data, and pipe write failures, are
// reported via triggerError.
func (conn *DataChannel) handleRawData(data []byte) {
	// Exclusive lock: this path mutates rawDataRest and rawDataWriter.
	conn.rawDataLock.Lock()
	defer conn.rawDataLock.Unlock()

	if conn.rawDataRest == 0 {
		conn.triggerError(ErrUnexpectedRawDataTransmission)
		return
	}
	size := uint64(len(data))
	if conn.rawDataRest < size {
		conn.triggerError(ErrUnexpectedOverlongRawData)
		return
	}
	// io.Pipe writes block until the reader side consumes the data, so the
	// lock is held while waiting here.
	if _, err := conn.rawDataWriter.Write(data); err != nil {
		// Do not account for bytes that were not delivered.
		conn.triggerError(err)
		return
	}
	conn.rawDataRest -= size
	if conn.rawDataRest == 0 {
		conn.rawDataWriter.CloseWithError(nil)
		conn.rawDataWriter = nil
	}
}

// handleTextMessage decodes a text frame as a JSON-encoded Message and
// dispatches it; a frame that fails to decode is reported as
// ErrProtocolViolation.
func (conn *DataChannel) handleTextMessage(data []byte) {
	m := new(Message)
	if err := json.NewDecoder(bytes.NewReader(data)).Decode(m); err != nil {
		conn.triggerError(ErrProtocolViolation)
		return
	}
	conn.triggerSendaroundMessage(m)
}
// AbortRawDataTransmission cancels any raw data transfer in progress. The
// pipe's write half is closed with err, so subsequent Reads (once any data
// already buffered is drained) return err, or io.EOF if err is nil. The
// expected remaining length is reset to zero.
func (conn *DataChannel) AbortRawDataTransmission(err error) {
	conn.rawDataLock.Lock()
	defer conn.rawDataLock.Unlock()

	conn.rawDataRest = 0
	conn.rawDataReader = nil
	if writer := conn.rawDataWriter; writer != nil {
		conn.rawDataWriter = nil
		writer.CloseWithError(err)
	}
}
// ExpectRawData prepares the channel to receive exactly expectedLength bytes
// of binary frames, which are then available through Read.
//
// A transfer that was still pending is cut off with io.ErrUnexpectedEOF so a
// reader blocked on it is released instead of hanging forever. An
// expectedLength of 0 produces an already-finished stream: Read returns
// io.EOF right away rather than waiting for frames that will never come.
func (conn *DataChannel) ExpectRawData(expectedLength uint64) {
	conn.rawDataLock.Lock()
	defer conn.rawDataLock.Unlock()

	if conn.rawDataWriter != nil {
		conn.rawDataWriter.CloseWithError(io.ErrUnexpectedEOF)
	}

	conn.rawDataReader, conn.rawDataWriter = io.Pipe()
	conn.rawDataReaderBuffered = bufio.NewReaderSize(conn.rawDataReader, 16*1024*1024)
	conn.rawDataRest = expectedLength

	if expectedLength == 0 {
		// No binary frame will arrive to complete the transfer, so finish it now.
		conn.rawDataWriter.CloseWithError(nil)
		conn.rawDataWriter = nil
	}
}
// SendMessage JSON-encodes msg (with the encoder's trailing newline) and
// sends it as a single text frame. It returns an error if msg is nil, if
// encoding fails, or if the underlying channel rejects the send.
func (conn *DataChannel) SendMessage(msg *Message) error {
	if msg == nil {
		return errors.New("msg must not be nil")
	}
	buf := new(bytes.Buffer)
	if err := json.NewEncoder(buf).Encode(msg); err != nil {
		return err
	}
	// Report send failures to the caller instead of panicking: this is
	// library code and the method already has an error return.
	return conn.dataChannel.SendText(buf.String())
}
// Read implements io.Reader over the raw data announced by the most recent
// ExpectRawData call. It fails if no raw data transfer was ever set up.
func (conn *DataChannel) Read(p []byte) (n int, err error) {
	buffered := conn.rawDataReaderBuffered
	if buffered == nil {
		return 0, errors.New("Unexpected raw data read")
	}
	return buffered.Read(p)
}
// Write sends p as a single binary frame, implementing io.Writer.
//
// To avoid queueing unbounded amounts of data, after a successful send it
// blocks while the channel's buffered amount exceeds ten times the
// low-water threshold, until the OnBufferedAmountLow handler signals.
func (conn *DataChannel) Write(p []byte) (n int, err error) {
	if err = conn.dataChannel.Send(p); err != nil {
		// The data was not sent, so there is nothing to wait for; on a failed
		// or closed channel the low-water signal may never arrive.
		return 0, err
	}
	if conn.dataChannel.BufferedAmount() > conn.dataChannel.BufferedAmountLowThreshold()*10 {
		<-conn.bufferWaitChannel
	}
	return len(p), nil
}