package sendaround import ( "io" "sync" "github.com/icedream/sendaround/internal" ) type Downloader struct { dataChannel *internal.DataChannel filesLock sync.RWMutex files map[string]RemoteFile err error stateLock sync.RWMutex state ConnectionState stateC chan ConnectionState CopyBufferSize int } func (client *Downloader) init() { if client.CopyBufferSize == 0 { client.CopyBufferSize = DefaultCopyBufferSize } client.stateC = make(chan ConnectionState, 1) client.files = map[string]RemoteFile{} client.changeState(ConnectionState{Type: WaitingForClient}) client.dataChannel.OnClose(func() { switch client.state.Type { case TransmittingFile: client.changeState(ConnectionState{Type: Failed, Error: ErrClientClosedConnection}) case Failed: case Disconnected: default: client.changeState(ConnectionState{Type: Disconnected}) } close(client.stateC) }) client.dataChannel.OnError(func(err error) { client.abortConnection(err) }) client.dataChannel.OnSendaroundMessage(func(msg *internal.Message) { switch { case msg.FileOfferMessage != nil: if err := client.addFile(&fileRemote{ fileName: msg.FileOfferMessage.FileName, length: msg.FileOfferMessage.Length, mimeType: msg.FileOfferMessage.MimeType, }); err != nil { client.abortConnection(err) } case msg.FileUnofferMessage != nil: client.removeFile(msg.FileUnofferMessage.FileName) case msg.SessionInitializedMessage != nil: client.changeState(ConnectionState{Type: Connected}) default: // Something's wrong with this message... } }) } func (client *Downloader) StateC() chan ConnectionState { return client.stateC } func (client *Downloader) Files() map[string]RemoteFile { client.filesLock.RLock() defer client.filesLock.RUnlock() retval := map[string]RemoteFile{} for filePath, f := range client.files { retval[filePath] = f } return retval } func (client *Downloader) RetrieveFile(filePath string) (r io.ReadCloser, err error) { client.stateLock.Lock() defer client.stateLock.Unlock() if client.state.Type != Connected { err = ErrInvalidState return } filePath = normalizeFilePath(filePath) state := ConnectionState{Type: TransmittingFile, CurrentFile: client.files[filePath]} client._unlocked_changeState(state) client.dataChannel.ExpectRawData(state.CurrentFile.Length()) client.dataChannel.SendMessage(&internal.Message{ FileTransferRequestMessage: &internal.FileTransferRequestMessage{ FileName: normalizeFilePath(filePath), }, }) r, w := io.Pipe() go func(w *io.PipeWriter) { var err error defer func() { if err != nil { w.CloseWithError(err) client.abortConnection(err) } else { w.CloseWithError(io.EOF) state.Type = Connected client.changeState(state) } }() var n int b := make([]byte, client.CopyBufferSize) for { n, err = client.dataChannel.Read(b) if err == io.EOF { err = nil break } _, err = w.Write(b[0:n]) if err != nil { return } state.TransmittedLength += uint64(n) client._unlocked_changeState(state) } }(w) return } func (client *Downloader) Close() error { client.stateLock.Lock() defer client.stateLock.Unlock() if client.state.Type == Failed || client.state.Type == Disconnected { return nil } if client.state.Type == TransmittingFile { return ErrInvalidState } client.dataChannel.Close() return nil } func (client *Downloader) abortConnection(err error) { client.stateLock.Lock() defer client.stateLock.Unlock() client.dataChannel.Close() state := client.state state.Type = Failed state.Error = err client.changeState(state) } func (client *Downloader) _unlocked_changeState(state ConnectionState) { client.state = state client.stateC <- state } func (client *Downloader) changeState(state ConnectionState) { client.stateLock.Lock() defer client.stateLock.Unlock() client._unlocked_changeState(state) } func (client *Downloader) addFile(fileToAdd RemoteFile) (err error) { client.filesLock.Lock() defer client.filesLock.Unlock() if _, exists := client.files[fileToAdd.FileName()]; exists { err = ErrFileAlreadyExists return } client.files[fileToAdd.FileName()] = fileToAdd return } func (client *Downloader) removeFile(filePath string) { client.filesLock.Lock() defer client.filesLock.Unlock() delete(client.files, normalizeFilePath(filePath)) }