/*
* @Author: xiangcai
* @Date: 2021-11-02 19:52:51
* @LastEditors: xiangcai
* @LastEditTime: 2021-11-03 15:40:23
* @Description: concurrent HTTP downloader that fetches a URL in parallel byte ranges and merges the parts into one file
*/
package main
import (
"encoding/json"
"errors"
"flag"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"os"
"os/signal"
"path"
"runtime"
"strconv"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
"github.com/sgs921107/gcommon"
"github.com/sgs921107/glogging"
)
// Package-wide logging objects. NewDownloader hands zapLogger to every
// downloader it creates, and main uses it directly.
var (
	// zapLogging is the glogging backend built with default options.
	zapLogging = glogging.NewZapLogging(glogging.Options{})
	// zapLogger is the sugared (key/value style) logger derived from zapLogging.
	zapLogger = zapLogging.GetSugaredLogger()
)

// barInterval is the pause between two redraws of the progress bar.
var barInterval = 100 * time.Millisecond
// Headers maps HTTP header names to the values to send.
type Headers map[string]string

// ConcurrentDownloader downloads a single URL by splitting it into byte
// ranges that are fetched in parallel into part files and then merged.
type ConcurrentDownloader struct {
	// downloadedSize counts the bytes saved so far across all tasks. It is
	// updated with sync/atomic 64-bit operations, so it must stay the FIRST
	// field: on 32-bit platforms only the first word of an allocated struct
	// is guaranteed to be 64-bit aligned.
	downloadedSize int64
	// headers are added to every request.
	headers Headers
	// Concurrent is the number of parallel range requests.
	Concurrent int
	client     http.Client
	// URI is the download link.
	URI string
	// filePath is the destination file; part files are "<filePath>.<n>".
	filePath string
	logger   *glogging.ZapSugaredLogger
	wg       *sync.WaitGroup
	// size is the total size of the remote resource in bytes.
	size int64
	// exit asks the download and progress-bar loops to stop.
	// NOTE(review): written by the signal handler and read by other goroutines
	// without synchronization (a data race under -race); consider atomic access.
	exit bool
}
// SendReq sends a request with the given method and body to c.URI and
// returns the raw response; the caller must close its body.
//
// The downloader's default headers (c.headers) are added first. The optional
// per-call headers are then Set on top, so they replace any default with the
// same name (for example a per-task Range) instead of duplicating it.
func (c ConcurrentDownloader) SendReq(
	method string,
	body io.Reader,
	headers *Headers,
) (*http.Response, error) {
	request, err := http.NewRequest(method, c.URI, body)
	// Check first: request is nil when NewRequest fails.
	if err != nil {
		c.logger.Errorw("Create Request Failed",
			"errMsg", err.Error(),
		)
		return nil, err
	}
	for key, value := range c.headers {
		request.Header.Add(key, value)
	}
	if headers != nil {
		for key, value := range *headers {
			request.Header.Set(key, value)
		}
	}
	return c.client.Do(request)
}
// SubDownload fetches the inclusive byte range [start, end] of c.URI into the
// part file "<filePath>.<num>". It is written to run as a goroutine: it
// always calls c.wg.Done, adds every byte it accounts for to
// c.downloadedSize, and stops early (returning nil) once c.exit is set.
//
// Resume: if the part file already exists, its bytes are counted as
// downloaded and only the rest of the range is requested. When nothing is
// left to fetch, a possibly empty part file is still left behind so that
// every part exists for merging.
func (c *ConcurrentDownloader) SubDownload(num, start, end int) error {
	defer c.wg.Done()
	c.logger.Infof("Run Task %d, Range %d - %d", num, start, end)
	subFilePath := fmt.Sprintf("%s.%d", c.filePath, num)

	fileInfo, err := os.Stat(subFilePath)
	switch {
	case err == nil:
		// Part file from an earlier run: count it and resume after it.
		fileSize := fileInfo.Size()
		atomic.AddInt64(&c.downloadedSize, fileSize)
		// The range is inclusive, so it spans end-start+1 bytes.
		if int(fileSize) >= end-start+1 {
			c.logger.Infof("Task %d Already Downloaded", num)
			return nil
		}
		start += int(fileSize)
	case !os.IsNotExist(err):
		// fileInfo is nil here, so it must not be used.
		c.logger.Errorw("Stat SubFile Error",
			"filename", subFilePath,
			"errMsg", err.Error(),
		)
		return err
	}

	file, err := os.OpenFile(subFilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
	if err != nil {
		c.logger.Errorw("Open SubFile Error",
			"filename", subFilePath,
			"errMsg", err.Error(),
		)
		return err
	}
	defer file.Close()

	if start > end {
		// Empty range: a "bytes=start-end" header with start > end is invalid
		// and would be ignored by the server, returning the whole resource.
		c.logger.Infof("Task %d Already Downloaded", num)
		return nil
	}

	headers := &Headers{
		"Range": fmt.Sprintf("bytes=%d-%d", start, end),
	}
	resp, err := c.SendReq("GET", nil, headers)
	if err != nil {
		c.logger.Errorw("Reqeust Error",
			"taskNum", num,
			"errMsg", err.Error(),
		)
		return err
	}
	defer resp.Body.Close()

	switch {
	case resp.StatusCode == http.StatusPartialContent:
		// The server honoured the Range header.
	case resp.StatusCode == http.StatusRequestedRangeNotSatisfiable:
		// start lies past the end of the resource: nothing left to fetch.
		c.logger.Infof("Task %d Already Downloaded", num)
		return nil
	case resp.StatusCode == http.StatusOK && start == 0 &&
		resp.ContentLength >= 0 && resp.ContentLength <= int64(end-start+1):
		// Range ignored, but the whole body fits inside this task's range
		// (e.g. a single-task download), so saving it verbatim is correct.
	default:
		// Anything else (a full body for a later range, an error page, ...)
		// would corrupt the part file.
		c.logger.Errorw("Unexpected Response Status",
			"taskNum", num,
			"status", resp.Status,
		)
		return fmt.Errorf("task %d: unexpected response status %q for range %d-%d",
			num, resp.Status, start, end)
	}

	// Stream the body; the buffer is reused across reads.
	buffer := make([]byte, 32*1024)
	for {
		if c.exit {
			c.logger.Infof("Task %d Interrupted", num)
			return nil
		}
		n, readErr := resp.Body.Read(buffer)
		if n > 0 {
			// Only the first n bytes of buffer were filled by this Read.
			if _, err := file.Write(buffer[:n]); err != nil {
				c.logger.Errorw("Write SubFile Error",
					"filename", subFilePath,
					"errMsg", err.Error(),
				)
				return err
			}
			atomic.AddInt64(&c.downloadedSize, int64(n))
		}
		if readErr == io.EOF {
			break
		}
		if readErr != nil {
			c.logger.Errorw("Read Resp Body Error",
				"errMsg", readErr.Error(),
			)
			return readErr
		}
	}
	c.logger.Infof("Task %d Done", num)
	return nil
}
// MergeSubFiles concatenates the part files "<filePath>.0" through
// "<filePath>.<Concurrent-1>", in order, into filePath (truncating any
// previous content) and then deletes the parts.
//
// Parts are deleted only after every part has been written and the output
// file has been closed successfully. A failed merge therefore leaves the
// parts in place, so the merge can be retried without downloading again.
func (c ConcurrentDownloader) MergeSubFiles() error {
	file, err := os.OpenFile(c.filePath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
	if err != nil {
		return err
	}
	partPaths := make([]string, 0, c.Concurrent)
	for i := 0; i < c.Concurrent; i++ {
		partPath := fmt.Sprintf("%s.%d", c.filePath, i)
		// ReadFile opens and closes the part itself, so no descriptor stays
		// open past this iteration.
		data, err := ioutil.ReadFile(partPath)
		if err != nil {
			file.Close()
			return err
		}
		if _, err := file.Write(data); err != nil {
			file.Close()
			return err
		}
		partPaths = append(partPaths, partPath)
	}
	// A failed Close can mean buffered data never reached the disk.
	if err := file.Close(); err != nil {
		return err
	}
	for _, partPath := range partPaths {
		// Best effort: the merged file is already complete.
		os.Remove(partPath)
	}
	return nil
}
// ProgressBar redraws a single-line progress bar on stdout every barInterval
// until c.exit is set, then ends the line.
//
// Progress is downloadedSize*100/size, clamped to [0, 100]. A zero size
// therefore cannot divide by zero, and an over-count (e.g. resumed part
// files) cannot yield a negative bar length. A zero size is shown as 100%.
func (c *ConcurrentDownloader) ProgressBar() {
	for !c.exit {
		percentagePoint := int64(100)
		if c.size > 0 {
			percentagePoint = atomic.LoadInt64(&c.downloadedSize) * 100 / c.size
		}
		if percentagePoint < 0 {
			percentagePoint = 0
		} else if percentagePoint > 100 {
			percentagePoint = 100
		}
		// Each bar character stands for 2 percentage points.
		fmt.Printf(
			"\rProgress: %d%% |%s%s|",
			percentagePoint,
			strings.Repeat("▋", int(percentagePoint/2)),
			strings.Repeat(".", int((100-percentagePoint)/2)),
		)
		time.Sleep(barInterval)
	}
	fmt.Println()
}
// singalHandler subscribes to SIGHUP, SIGINT, SIGTERM and SIGQUIT (e.g.
// ctrl+c or kill). On any of them it sets c.exit, so the download and
// progress-bar loops stop. It never returns and is meant to run in its own
// goroutine. SIGUSR1/SIGUSR2 are not subscribed.
func (c *ConcurrentDownloader) singalHandler() {
	incoming := make(chan os.Signal, 1)
	signal.Notify(incoming, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)

	isExitSignal := func(s os.Signal) bool {
		return s == syscall.SIGHUP || s == syscall.SIGINT ||
			s == syscall.SIGTERM || s == syscall.SIGQUIT
	}

	for {
		received := <-incoming
		sigLogger := c.logger.With("signal", received)
		if !isExitSignal(received) {
			sigLogger.Infow("Recv A Signal")
			continue
		}
		c.exit = true
		sigLogger.Warnw("Recv Exit Signal")
	}
}
// Download fetches c.URI into c.filePath.
//
// It starts the signal handler and queries the remote size. It returns nil
// immediately if c.filePath already holds at least that many bytes.
// Otherwise it splits the resource into up to c.Concurrent inclusive byte
// ranges, downloads them in parallel while showing a progress bar, verifies
// the total byte count and merges the part files.
//
// c.Concurrent is normalized to the number of non-empty ranges actually
// used, because MergeSubFiles merges exactly that many parts.
func (c *ConcurrentDownloader) Download() error {
	go c.singalHandler()
	c.logger.Infow("start downlaod",
		"url", c.URI,
	)
	size, err := c.GetTargetSize()
	if err != nil {
		c.logger.Errorw("Get Content Length Error",
			"errMsg", err.Error(),
		)
		return err
	}
	c.size = int64(size)
	if fileInfo, statErr := os.Stat(c.filePath); statErr == nil && int(fileInfo.Size()) >= size {
		c.logger.Infow("Already Downloaded",
			"url", c.URI,
		)
		return nil
	}
	c.logger.Infof("file size: %.2fM", float64(size)/(1024*1024))
	if size == 0 {
		// Nothing to fetch: no valid byte range exists for an empty resource.
		return ioutil.WriteFile(c.filePath, nil, 0644)
	}

	if c.Concurrent < 1 {
		c.Concurrent = 1
	}
	// Bytes per task. This is the same formula as before, so part files left
	// by an earlier run stay aligned when resuming.
	subSize := size/c.Concurrent + 1
	// Drop tasks whose range would start at or past EOF (small files).
	if parts := (size + subSize - 1) / subSize; parts < c.Concurrent {
		c.Concurrent = parts
	}
	for i := 0; i < c.Concurrent; i++ {
		// Inclusive range; the last byte of the resource is at size-1.
		start, end := subSize*i, subSize*(i+1)-1
		if end > size-1 {
			end = size - 1
		}
		c.wg.Add(1)
		go c.SubDownload(i, start, end)
	}

	barDone := make(chan struct{})
	go func() {
		c.ProgressBar()
		close(barDone)
	}()
	c.wg.Wait()
	// Let the bar draw the final percentage once more before stopping it.
	time.Sleep(barInterval)
	c.exit = true
	// Wait for the bar to finish its line before anything else is logged.
	<-barDone

	downloadedSize := atomic.LoadInt64(&c.downloadedSize)
	if downloadedSize < int64(size) {
		c.logger.Errorw("DownloadedSizeError",
			"ExpectSize", size,
			"DownloadedSize", downloadedSize,
		)
		return fmt.Errorf("incomplete download: got %d of %d bytes", downloadedSize, size)
	}
	if err := c.MergeSubFiles(); err != nil {
		c.logger.Errorw("ErrMergeSubFiles",
			"errMsg", err.Error(),
		)
		return err
	}
	return nil
}
// GetTargetSize sends a HEAD request to c.URI and returns the resource size
// taken from the Content-Length header.
//
// It fails on a non-2xx status, because an error page's Content-Length is
// not the file size. It also fails on a missing, unparsable or negative
// Content-Length.
func (c ConcurrentDownloader) GetTargetSize() (int, error) {
	resp, err := c.SendReq("HEAD", nil, nil)
	if err != nil {
		return 0, err
	}
	defer resp.Body.Close()
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return 0, fmt.Errorf("HEAD %s: unexpected status %q", c.URI, resp.Status)
	}
	rawLength := resp.Header.Get("Content-Length")
	if rawLength == "" {
		return 0, errors.New("response has no Content-Length header")
	}
	size, err := strconv.Atoi(rawLength)
	if err != nil {
		return 0, fmt.Errorf("parsing Content-Length %q: %w", rawLength, err)
	}
	if size < 0 {
		return 0, fmt.Errorf("negative Content-Length %d", size)
	}
	return size, nil
}
// NewDownloader validates its arguments and returns a downloader that saves
// uri into saveDir, named after the last element of the URL path.
//
// It returns:
//   - the url.Parse error if uri does not parse;
//   - os.ErrNotExist when gcommon.PathIsExist reports that saveDir is missing;
//   - "INVALID RESOURCE" when the URL path yields no usable file name;
//   - an error when concurrent < 1, which would otherwise cause a
//     divide-by-zero when the download is split.
//
// timeout becomes the http.Client timeout for every request; headers are
// sent with every request.
func NewDownloader(
	uri, saveDir string,
	concurrent int,
	timeout time.Duration,
	headers Headers,
) (*ConcurrentDownloader, error) {
	realURI, err := url.Parse(uri)
	if err != nil {
		return nil, err
	}
	if !gcommon.PathIsExist(saveDir) {
		return nil, os.ErrNotExist
	}
	_, filename := path.Split(realURI.Path)
	// "." and ".." would resolve to saveDir itself or its parent, not a file.
	if filename == "" || filename == "." || filename == ".." {
		return nil, errors.New("INVALID RESOURCE")
	}
	if concurrent < 1 {
		return nil, fmt.Errorf("concurrent must be at least 1, got %d", concurrent)
	}
	downloader := &ConcurrentDownloader{
		URI:        realURI.String(),
		Concurrent: concurrent,
		logger:     zapLogger,
		wg:         &sync.WaitGroup{},
		filePath:   path.Join(saveDir, filename),
		headers:    headers,
		client: http.Client{
			Timeout: timeout,
		},
	}
	return downloader, nil
}
// init sets GOMAXPROCS to the number of logical CPUs.
// NOTE(review): recent Go runtimes already pick a CPU-based default. This
// call overrides that default and any GOMAXPROCS environment variable;
// confirm it is still wanted.
func init() {
	cpuCount := runtime.NumCPU()
	runtime.GOMAXPROCS(cpuCount)
}
// main parses the command-line flags, builds a ConcurrentDownloader and runs
// it.
//
// Exit status:
//   - 2 for invalid arguments (missing -uri, malformed -headers JSON);
//   - 1 when the downloader cannot be created or the download fails;
//   - 0 on success.
func main() {
	var uri, saveDir, headers string
	var concurrent, timeout int
	flag.StringVar(&uri, "uri", "", "download link, required")
	flag.StringVar(&saveDir, "dir", ".", "download file to this dir")
	flag.IntVar(&concurrent, "n", 5, "concurrent num")
	flag.IntVar(&timeout, "t", 0, "download timeout, unit: Minute")
	flag.StringVar(&headers, "headers", "{}", "Request Headers, formater: json")
	flag.Parse()
	if uri == "" {
		zapLogger.Error("Param URI Is Required")
		flag.CommandLine.Usage()
		os.Exit(2)
	}
	reqHeaders := make(Headers)
	if err := json.Unmarshal([]byte(headers), &reqHeaders); err != nil {
		zapLogger.Errorw("Invalid Headers",
			"errMsg", err.Error(),
		)
		os.Exit(2)
	}
	downloader, err := NewDownloader(
		uri,
		saveDir,
		concurrent,
		time.Minute*time.Duration(timeout),
		reqHeaders,
	)
	if err != nil {
		zapLogger.Errorw("Instantiation Downloader Failed",
			"errMsg", err.Error(),
		)
		os.Exit(1)
	}
	// Download logs its own failures; only the exit status is set here.
	if err := downloader.Download(); err != nil {
		os.Exit(1)
	}
}
// example cmd: go run filename.go -uri=http://www.example.com/test.mp4