// Copyright 2019 The go-ethereum Authors // This file is part of the go-ethereum library. // // The go-ethereum library is free software: you can redistribute it and/or modify // it under the terms of the GNU Lesser General Public License as published by // the Free Software Foundation, either version 3 of the License, or // (at your option) any later version. // // The go-ethereum library is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Lesser General Public License for more details. // // You should have received a copy of the GNU Lesser General Public License // along with the go-ethereum library. If not, see . package build import ( "bufio" "crypto/sha256" "encoding/hex" "fmt" "io" "io/ioutil" "log" "net/http" "os" "path/filepath" "strings" ) // ChecksumDB keeps file checksums. type ChecksumDB struct { allChecksums []string } // MustLoadChecksums loads a file containing checksums. func MustLoadChecksums(file string) *ChecksumDB { content, err := ioutil.ReadFile(file) if err != nil { log.Fatal("can't load checksum file: " + err.Error()) } return &ChecksumDB{strings.Split(string(content), "\n")} } // Verify checks whether the given file is valid according to the checksum database. func (db *ChecksumDB) Verify(path string) error { fd, err := os.Open(path) if err != nil { return err } defer fd.Close() h := sha256.New() if _, err := io.Copy(h, bufio.NewReader(fd)); err != nil { return err } fileHash := hex.EncodeToString(h.Sum(nil)) if !db.findHash(filepath.Base(path), fileHash) { return fmt.Errorf("invalid file hash %s", fileHash) } return nil } func (db *ChecksumDB) findHash(basename, hash string) bool { want := hash + " " + basename for _, line := range db.allChecksums { if strings.TrimSpace(line) == want { return true } } return false } // DownloadFile downloads a file and verifies its checksum. func (db *ChecksumDB) DownloadFile(url, dstPath string) error { if err := db.Verify(dstPath); err == nil { fmt.Printf("%s is up-to-date\n", dstPath) return nil } fmt.Printf("%s is stale\n", dstPath) fmt.Printf("downloading from %s\n", url) resp, err := http.Get(url) if err != nil || resp.StatusCode != http.StatusOK { return fmt.Errorf("download error: code %d, err %v", resp.StatusCode, err) } defer resp.Body.Close() if err := os.MkdirAll(filepath.Dir(dstPath), 0755); err != nil { return err } fd, err := os.OpenFile(dstPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) if err != nil { return err } dst := newDownloadWriter(fd, resp.ContentLength) _, err = io.Copy(dst, resp.Body) dst.Close() if err != nil { return err } return db.Verify(dstPath) } type downloadWriter struct { file *os.File dstBuf *bufio.Writer size int64 written int64 lastpct int64 } func newDownloadWriter(dst *os.File, size int64) *downloadWriter { return &downloadWriter{ file: dst, dstBuf: bufio.NewWriter(dst), size: size, } } func (w *downloadWriter) Write(buf []byte) (int, error) { n, err := w.dstBuf.Write(buf) // Report progress. w.written += int64(n) pct := w.written * 10 / w.size * 10 if pct != w.lastpct { if w.lastpct != 0 { fmt.Print("...") } fmt.Print(pct, "%") w.lastpct = pct } return n, err } func (w *downloadWriter) Close() error { if w.lastpct > 0 { fmt.Println() // Finish the progress line. } flushErr := w.dstBuf.Flush() closeErr := w.file.Close() if flushErr != nil { return flushErr } return closeErr }