自定义升级包

自定义升级包

package main

import (
	"bytes"
	"flag"
	"fmt"
	"io/ioutil"
	"log"
	"os"
	"path"
	"time"
)

const (
	FLAG = "U&PK"
)

func main() {
	file := flag.String("file", "", "升级包原始zip文件, eg. kms-update-170191.zip")
	flag.Parse()
	if *file == "" {
		flag.Usage()
		fmt.Printf("\n打包失败: %v\n", "参数错误")
		return
	}
	nfile, err := GenUpPkgFile(*file)
	if err != nil {
		fmt.Printf("\n打包失败: %v\n", err)
		return
	}
	fmt.Printf("---------------[升级包打包完成]-----------------------\n")
	fmt.Printf("原始文件  : %v\n", *file)
	fmt.Printf("升级包文件: %v\n", nfile)
}

func RecoverUpPkgFile(filename string) error {
	b, err := checkUpPkgFile(filename)
	if err != nil {
		return err
	}
	if !b {
		return fmt.Errorf("file is invalid")
	}
	return recoverFileHeader(filename)
}

func checkUpPkgFile(filename string) (bool, error) {
	f, err := os.OpenFile(filename, os.O_RDWR, os.ModePerm)
	if err != nil {
		log.Fatal(err)
		return false, err
	}
	defer f.Close()

	fs, err := f.Stat()
	if err != nil {
		log.Fatal(err)
		return false, err
	}
	log.Printf("原始 文件名: %v, 文件大小: %v", filename, fs.Size())

	buf := make([]byte, 4)
	n, err := f.Read(buf)
	if err != nil {
		return false, err
	}
	log.Printf("读取文件头 len=%v, src=%v", n, buf)

	res := bytes.Compare([]byte(FLAG), buf)
	if res == 0 {
		return true, nil
	}
	return false, nil
}

func GenUpPkgFile(filename string) (string, error) {
	// new file name      .zip -> .up
	fileNameWithSuffix := path.Base(filename)
	fileExt := path.Ext(fileNameWithSuffix)
	if fileExt != ".zip" {
		return "", fmt.Errorf("文件格式不支持, %v", fileExt)
	}
	newfileWithoutSuffix := filename[:len(filename)-len(fileExt)]
	newfile := newfileWithoutSuffix + ".up"

	err := copyFile(filename, newfile)
	if err != nil {
		return "", fmt.Errorf("文件拷贝失败, %v", err)
	}
	err = updateFileHeader(newfile)
	if err != nil {
		return "", fmt.Errorf("打包文件失败, %v", err)
	}
	time.Sleep(20 * time.Second)
	return newfile, nil
}

func copyFile(filename, newfile string) error {
	input, err := ioutil.ReadFile(filename)
	if err != nil {
		return err
	}
	err = ioutil.WriteFile(newfile, input, 0644)
	if err != nil {
		return err
	}
	return nil
}

func updateFileHeader(filename string) error {
	f, err := os.OpenFile(filename, os.O_RDWR, 0644)
	if err != nil {
		return err
	}
	defer f.Close()

	if _, err := f.Seek(0, 0); err != nil {
		return err
	}

	flag := FLAG
	bytFlag := []byte(flag)
	if _, err := f.WriteAt(bytFlag, 0); err != nil {
		return err
	}
	return nil
}

func recoverFileHeader(filename string) error {
	f, err := os.OpenFile(filename, os.O_RDWR, 0644)
	if err != nil {
		panic(err)
	}
	defer f.Close()

	if _, err := f.Seek(0, 0); err != nil {
		panic(err)
	}

	if _, err := f.WriteAt([]byte{0x50, 0x4B, 0x1, 0x2}, 0); err != nil {
		panic(err)
	}
	return nil
}

func printPkgHeader(filename string) error {
	f, err := os.OpenFile(filename, os.O_RDWR, os.ModePerm)
	if err != nil {
		log.Fatal(err)
		return err
	}
	defer f.Close()

	fs, err := f.Stat()
	if err != nil {
		log.Fatal(err)
		return err
	}
	log.Printf("原始 文件名: %v, 文件大小: %v", filename, fs.Size())

	buf := make([]byte, 100)
	n, err := f.Read(buf)
	if err != nil {
		return err
	}
	log.Printf("读取文件头 len=%v, src=%v", n, buf)

	return nil
}
posted @ 2022-07-14 11:28  jiftle  阅读(59)  评论(0编辑  收藏  举报