自定义升级包
自定义升级包
package main
import (
"bytes"
"flag"
"fmt"
"io/ioutil"
"log"
"os"
"path"
"time"
)
const (
FLAG = "U&PK"
)
func main() {
file := flag.String("file", "", "升级包原始zip文件, eg. kms-update-170191.zip")
flag.Parse()
if *file == "" {
flag.Usage()
fmt.Printf("\n打包失败: %v\n", "参数错误")
return
}
nfile, err := GenUpPkgFile(*file)
if err != nil {
fmt.Printf("\n打包失败: %v\n", err)
return
}
fmt.Printf("---------------[升级包打包完成]-----------------------\n")
fmt.Printf("原始文件 : %v\n", *file)
fmt.Printf("升级包文件: %v\n", nfile)
}
func RecoverUpPkgFile(filename string) error {
b, err := checkUpPkgFile(filename)
if err != nil {
return err
}
if !b {
return fmt.Errorf("file is invalid")
}
return recoverFileHeader(filename)
}
func checkUpPkgFile(filename string) (bool, error) {
f, err := os.OpenFile(filename, os.O_RDWR, os.ModePerm)
if err != nil {
log.Fatal(err)
return false, err
}
defer f.Close()
fs, err := f.Stat()
if err != nil {
log.Fatal(err)
return false, err
}
log.Printf("原始 文件名: %v, 文件大小: %v", filename, fs.Size())
buf := make([]byte, 4)
n, err := f.Read(buf)
if err != nil {
return false, err
}
log.Printf("读取文件头 len=%v, src=%v", n, buf)
res := bytes.Compare([]byte(FLAG), buf)
if res == 0 {
return true, nil
}
return false, nil
}
func GenUpPkgFile(filename string) (string, error) {
// new file name .zip -> .up
fileNameWithSuffix := path.Base(filename)
fileExt := path.Ext(fileNameWithSuffix)
if fileExt != ".zip" {
return "", fmt.Errorf("文件格式不支持, %v", fileExt)
}
newfileWithoutSuffix := filename[:len(filename)-len(fileExt)]
newfile := newfileWithoutSuffix + ".up"
err := copyFile(filename, newfile)
if err != nil {
return "", fmt.Errorf("文件拷贝失败, %v", err)
}
err = updateFileHeader(newfile)
if err != nil {
return "", fmt.Errorf("打包文件失败, %v", err)
}
time.Sleep(20 * time.Second)
return newfile, nil
}
func copyFile(filename, newfile string) error {
input, err := ioutil.ReadFile(filename)
if err != nil {
return err
}
err = ioutil.WriteFile(newfile, input, 0644)
if err != nil {
return err
}
return nil
}
func updateFileHeader(filename string) error {
f, err := os.OpenFile(filename, os.O_RDWR, 0644)
if err != nil {
return err
}
defer f.Close()
if _, err := f.Seek(0, 0); err != nil {
return err
}
flag := FLAG
bytFlag := []byte(flag)
if _, err := f.WriteAt(bytFlag, 0); err != nil {
return err
}
return nil
}
func recoverFileHeader(filename string) error {
f, err := os.OpenFile(filename, os.O_RDWR, 0644)
if err != nil {
panic(err)
}
defer f.Close()
if _, err := f.Seek(0, 0); err != nil {
panic(err)
}
if _, err := f.WriteAt([]byte{0x50, 0x4B, 0x1, 0x2}, 0); err != nil {
panic(err)
}
return nil
}
func printPkgHeader(filename string) error {
f, err := os.OpenFile(filename, os.O_RDWR, os.ModePerm)
if err != nil {
log.Fatal(err)
return err
}
defer f.Close()
fs, err := f.Stat()
if err != nil {
log.Fatal(err)
return err
}
log.Printf("原始 文件名: %v, 文件大小: %v", filename, fs.Size())
buf := make([]byte, 100)
n, err := f.Read(buf)
if err != nil {
return err
}
log.Printf("读取文件头 len=%v, src=%v", n, buf)
return nil
}