通过AWS STS临时授权凭证分片上传文件

1.《获取STS临时授权凭证》

2.《通过STS Token分片上传文件》

3.《文件预签名URL》

一、相关文档

1.AWS分片上传文档:https://docs.aws.amazon.com/zh_cn/AmazonS3/latest/userguide/mpuoverview.html

2.获取AWS STS临时授权凭证的Go示例（见本系列第1篇《获取STS临时授权凭证》）

二、Go示例

package main

import (
    "context"
    "errors"
    "fmt"
    "github.com/aws/aws-sdk-go-v2/aws"
    "github.com/aws/aws-sdk-go-v2/config"
    "github.com/aws/aws-sdk-go-v2/credentials"
    "github.com/aws/aws-sdk-go-v2/service/s3"
    "github.com/aws/aws-sdk-go-v2/service/s3/types"
    "github.com/aws/aws-sdk-go-v2/service/sts"
    "github.com/cubewise-code/go-mime"
    "io"
    "os"
    "path/filepath"
    "strconv"
    "strings"
    "time"
)

// fileChunkSize is the part size (in bytes) used when splitting a file for
// multipart upload: 100 MiB. A value between 50 MiB and 100 MiB is recommended.
var fileChunkSize int64 = 100 << 20

// main demonstrates the full flow: obtain STS credentials restricted to
// authPaths, then upload a local file to S3 with those credentials.
func main() {
	ctx := context.Background()

	var bucketName = "{bucket名称}"
	// Key prefixes the STS credentials may access. Replace "uid" with the directory
	// of the real user's uid; multi-level prefixes are supported.
	var authPaths = []string{"uid", "output"}
	var expire int32 = 3600 // STS token lifetime in seconds (AssumeRole DurationSeconds)
	cfg := &StoreClientConf{
		RoleArn:         "{roleArn}",
		Region:          "{bucket region}",
		AccessKeyID:     "{bucket ak}",
		AccessKeySecret: "{bucket sk}",
	}
	client := NewAwsClient(cfg)

	// 1. Obtain STS credentials scoped to authPaths.
	stsInfo, err := client.GetStsCredentials(ctx, bucketName, authPaths, expire)
	if err != nil {
		fmt.Println("client.GetStsCredentials err: " + err.Error())
		return
	}
	// Printing the secret and session token is for demonstration only; never log
	// them in production.
	fmt.Println("sts ak: " + stsInfo.AccessKeyId)
	fmt.Println("sts sk: " + stsInfo.AccessSecret)
	fmt.Println("sts token: " + stsInfo.SecurityToken)

	// 2. Upload a local file using the STS credentials.
	var objectKey = "output/hls/20240425-people_low.m3u8" // S3 object key; must lie under one of authPaths
	var fileDir = "/vagrant/20240425-people_low.m3u8"      // local file path
	err = client.UploadByToken(ctx, stsInfo, fileDir, bucketName, objectKey)
	if err != nil {
		fmt.Println("client.UploadByToken err: " + err.Error())
		// Stop here so a failed upload is not followed by a "success" message.
		return
	}
	fmt.Println("client.UploadByToken success")
}

// AwsClient holds the role to assume, the region every SDK client it builds
// targets, and the static access key pair (no session token) used to call STS.
type AwsClient struct {
	roleArn, region, accessKeyID, accessKeySecret string
}

// StoreClientConf is the caller-supplied configuration consumed by NewAwsClient:
// the role ARN to assume, the bucket region, and the access key pair used to call STS.
type StoreClientConf struct {
	RoleArn, Region, AccessKeyID, AccessKeySecret string
}

// StsCredentials is the temporary credential set produced by GetStsCredentials.
type StsCredentials struct {
	AccessKeyId, AccessSecret, SecurityToken string
	// ExpireTime is the credential expiration as a Unix timestamp in seconds,
	// or 0 when STS did not report an expiration.
	ExpireTime int64
}

// NewAwsClient returns an AwsClient populated from cfg. cfg must be non-nil;
// its fields are copied, so later changes to cfg do not affect the client.
func NewAwsClient(cfg *StoreClientConf) *AwsClient {
	client := new(AwsClient)
	client.roleArn = cfg.RoleArn
	client.region = cfg.Region
	client.accessKeyID = cfg.AccessKeyID
	client.accessKeySecret = cfg.AccessKeySecret
	return client
}

// loadConfig builds an AWS SDK config for s.region that authenticates with the
// client's static access key pair (no session token). GetStsCredentials uses it
// to call STS.
func (s *AwsClient) loadConfig(ctx context.Context) (aws.Config, error) {
	cfg, err := config.LoadDefaultConfig(ctx,
		config.WithRegion(s.region),
		config.WithCredentialsProvider(credentials.StaticCredentialsProvider{
			Value: aws.Credentials{
				AccessKeyID:     s.accessKeyID,
				SecretAccessKey: s.accessKeySecret,
			},
		}),
	)
	if err != nil {
		// Wrap rather than replace the error so the root cause reaches the caller.
		return aws.Config{}, fmt.Errorf("awsClient LoadDefaultConfig err: %w", err)
	}
	return cfg, nil
}

// loadConfigByToken builds an AWS SDK config for s.region that authenticates
// with the given temporary credentials: access key ak, secret sk, and session
// token token.
func (s *AwsClient) loadConfigByToken(ctx context.Context, ak, sk, token string) (aws.Config, error) {
	cfg, err := config.LoadDefaultConfig(ctx,
		config.WithRegion(s.region),
		config.WithCredentialsProvider(credentials.StaticCredentialsProvider{
			Value: aws.Credentials{
				AccessKeyID:     ak,
				SecretAccessKey: sk,
				SessionToken:    token,
			},
		}),
	)
	if err != nil {
		// Wrap rather than replace the error so the root cause reaches the caller.
		return aws.Config{}, fmt.Errorf("load config error: %w", err)
	}
	return cfg, nil
}

// authPolicy returns the JSON session policy passed to AssumeRole. It grants the
// object-level S3 actions listed below on arn:aws:s3:::<bucket>/<path>/* for each
// entry in authPaths, after trimming trailing slashes from the path. ctx is unused.
//
// NOTE(review): bucket and paths are concatenated into the JSON verbatim. A value
// containing `"` or `\` would corrupt or restructure the policy, and IAM wildcards
// (`*`, `?`) inside a path widen the match. Validate user-derived paths first.
func (s *AwsClient) authPolicy(ctx context.Context, bucket string, authPaths []string) string {
	resources := make([]string, 0, len(authPaths))
	for _, v := range authPaths {
		path := strings.TrimRight(v, "/") // strip trailing slashes so the ARN ends in exactly "/*"
		resources = append(resources, `"arn:aws:s3:::`+bucket+`/`+path+`/*"`)
	}
	// s3:PutObject also authorizes CreateMultipartUpload, UploadPart and
	// CompleteMultipartUpload; IAM has no separate "s3:UploadPart" action.
	// s3:AbortMultipartUpload is required to clean up a failed multipart upload.
	policy := `{
    "Version": "2012-10-17",
    "Statement": [
        {
            "Action": [
                "s3:GetObject",
                "s3:GetObjectAttributes",
                "s3:GetObjectTagging",
                "s3:PutObject",
                "s3:PutObjectTagging",
                "s3:AbortMultipartUpload"
            ],
            "Effect": "Allow",
            "Resource": [` + strings.Join(resources, ",") + `]
        }
    ]
}`
	return policy
}

// GetStsCredentials assumes s.roleArn with a session policy (see authPolicy)
// that limits access to bucket/<path>/* for each entry in authPaths, and returns
// the resulting temporary credentials.
//
// expired is the credential lifetime in seconds (AssumeRole DurationSeconds).
// A non-nil *StsCredentials is returned only together with a nil error.
func (s *AwsClient) GetStsCredentials(ctx context.Context, bucket string, authPaths []string, expired int32) (*StsCredentials, error) {
	// 1. Build the session policy.
	policy := s.authPolicy(ctx, bucket, authPaths)

	// 2. Create an STS client authenticated with the static key pair.
	cfg, err := s.loadConfig(ctx)
	if err != nil {
		return nil, err
	}
	client := sts.NewFromConfig(cfg)

	// 3. Assume the role.
	// TODO: derive the session name from the end user rather than only the time.
	roleSessionName := "s3bucket" + strconv.FormatInt(time.Now().Unix(), 10)
	resp, err := client.AssumeRole(ctx, &sts.AssumeRoleInput{
		RoleArn:         aws.String(s.roleArn),
		RoleSessionName: aws.String(roleSessionName),
		DurationSeconds: aws.Int32(expired),
		Policy:          aws.String(policy),
	})
	if err != nil {
		return nil, fmt.Errorf("GetStsCredentials client.AssumeRole err: %w", err)
	}
	// Every pointer dereferenced below must be checked; returning (nil, nil) here
	// would make callers dereference a nil *StsCredentials.
	if resp == nil || resp.Credentials == nil ||
		resp.Credentials.AccessKeyId == nil ||
		resp.Credentials.SecretAccessKey == nil ||
		resp.Credentials.SessionToken == nil {
		return nil, errors.New("GetStsCredentials response has no credentials")
	}
	creds := resp.Credentials

	var expire int64 // stays 0 when STS reports no expiration
	if creds.Expiration != nil {
		expire = creds.Expiration.Unix()
	}
	return &StsCredentials{
		AccessKeyId:   *creds.AccessKeyId,
		AccessSecret:  *creds.SecretAccessKey,
		SecurityToken: *creds.SessionToken,
		ExpireTime:    expire,
	}, nil
}

// FileChunk describes one part of a file for multipart upload.
type FileChunk struct {
	Number int   // 1-based part number
	Offset int64 // byte offset of the chunk within the file
	Size   int64 // chunk length in bytes
}

// S3Part pairs a part number with the ETag S3 returned for that part.
type S3Part struct {
	Number int
	ETag   *string
}

// SplitFileByPartSize splits the file at fileName into consecutive chunks of
// chunkSize bytes, numbered from 1; the final chunk holds any remainder. An
// empty file yields no chunks.
//
// It returns an error if chunkSize is not positive, if the file cannot be
// opened or stat'ed, or if the split would need more than 10000 parts.
func SplitFileByPartSize(fileName string, chunkSize int64) ([]FileChunk, error) {
	if chunkSize <= 0 {
		return nil, errors.New("chunkSize invalid")
	}

	file, err := os.Open(fileName)
	if err != nil {
		return nil, err
	}
	defer file.Close()

	stat, err := file.Stat()
	if err != nil {
		return nil, err
	}
	fileSize := stat.Size()

	// Ceiling division, written to avoid the overflow of (size+chunkSize-1)/chunkSize.
	partCount := fileSize / chunkSize
	if fileSize%chunkSize != 0 {
		partCount++
	}
	// maxParts is the S3 limit on parts per multipart upload; exactly maxParts is allowed.
	const maxParts = 10000
	if partCount > maxParts {
		return nil, errors.New("Too many parts, please increase part size")
	}

	chunks := make([]FileChunk, 0, partCount)
	for i := int64(0); i < partCount; i++ {
		offset := i * chunkSize
		partSize := chunkSize
		if rest := fileSize - offset; rest < partSize {
			partSize = rest // last, shorter chunk
		}
		chunks = append(chunks, FileChunk{Number: int(i + 1), Offset: offset, Size: partSize})
	}
	return chunks, nil
}

// UploadByToken uploads the local file at filePath to bucketName/objectKey
// using the temporary credentials in stsInfo.
//
// A file that fits in at most one chunk of fileChunkSize bytes is sent with a
// single PutObject; larger files use a multipart upload with one UploadPart per
// chunk. If any part upload or the final CompleteMultipartUpload fails, the
// multipart upload is aborted (best effort) so no orphaned parts remain.
func (s *AwsClient) UploadByToken(ctx context.Context, stsInfo *StsCredentials, filePath, bucketName, objectKey string) error {
	if stsInfo == nil {
		return errors.New("UploadByToken: stsInfo is nil")
	}

	// 1. Build an S3 client authenticated with the STS credentials.
	cfg, err := s.loadConfigByToken(ctx, stsInfo.AccessKeyId, stsInfo.AccessSecret, stsInfo.SecurityToken)
	if err != nil {
		return err
	}
	s3Client := s3.NewFromConfig(cfg)

	// 2. Plan the parts. The error must be checked: on failure chunks is nil,
	// which would otherwise fall through to a single PutObject of the whole file.
	chunks, err := SplitFileByPartSize(filePath, fileChunkSize)
	if err != nil {
		return fmt.Errorf("SplitFileByPartSize err: %w", err)
	}
	for _, v := range chunks {
		fmt.Println("chunk: ", v.Number, ", size:", v.Size, ", offset:", v.Offset)
	}

	// 3. Open the file.
	fd, err := os.Open(filePath)
	if err != nil {
		return err
	}
	defer fd.Close()

	contentType := gomime.TypeByExtension(filepath.Ext(fd.Name()))

	// 4. Single chunk (or empty file): a plain PutObject is enough.
	if len(chunks) <= 1 {
		_, err = s3Client.PutObject(ctx, &s3.PutObjectInput{
			Bucket:      aws.String(bucketName),
			Key:         aws.String(objectKey),
			Body:        fd,
			ContentType: aws.String(contentType),
		})
		if err != nil {
			return fmt.Errorf("s3Client.PutObject err: %w", err)
		}
		return nil
	}

	// 4.1 Start the multipart upload.
	imur, err := s3Client.CreateMultipartUpload(ctx, &s3.CreateMultipartUploadInput{
		Bucket:      aws.String(bucketName),
		Key:         aws.String(objectKey),
		ContentType: aws.String(contentType),
	})
	if err != nil {
		return fmt.Errorf("s3Client.CreateMultipartUpload err: %w", err)
	}

	// abort cancels the multipart upload after a failure and returns cause,
	// annotated with the abort error if the abort itself also failed.
	// NOTE(review): if ctx is already cancelled, the abort request fails too.
	abort := func(cause error) error {
		_, abortErr := s3Client.AbortMultipartUpload(ctx, &s3.AbortMultipartUploadInput{
			Bucket:   aws.String(bucketName),
			Key:      aws.String(objectKey),
			UploadId: imur.UploadId,
		})
		if abortErr != nil {
			return fmt.Errorf("%w (s3Client.AbortMultipartUpload err: %v)", cause, abortErr)
		}
		return cause
	}

	// 4.2 Upload each part.
	completed := make([]types.CompletedPart, 0, len(chunks))
	for _, chunk := range chunks {
		// A SectionReader reads exactly this chunk through ReadAt and is seekable,
		// so the SDK can rewind the body if it retries the request. A LimitedReader
		// over the shared file offset cannot be rewound.
		out, err := s3Client.UploadPart(ctx, &s3.UploadPartInput{
			Bucket:        aws.String(bucketName),
			Key:           aws.String(objectKey),
			PartNumber:    aws.Int32(int32(chunk.Number)),
			UploadId:      imur.UploadId,
			Body:          io.NewSectionReader(fd, chunk.Offset, chunk.Size),
			ContentLength: aws.Int64(chunk.Size),
		})
		if err != nil {
			return abort(fmt.Errorf("s3Client.UploadPart err (part %d): %w", chunk.Number, err))
		}
		completed = append(completed, types.CompletedPart{
			ETag:       out.ETag,
			PartNumber: aws.Int32(int32(chunk.Number)),
		})
	}

	// 4.3 Complete the multipart upload; abort on failure so parts are not left behind.
	_, err = s3Client.CompleteMultipartUpload(ctx, &s3.CompleteMultipartUploadInput{
		Bucket:          aws.String(bucketName),
		Key:             aws.String(objectKey),
		UploadId:        imur.UploadId,
		MultipartUpload: &types.CompletedMultipartUpload{Parts: completed},
	})
	if err != nil {
		return abort(fmt.Errorf("s3Client.CompleteMultipartUpload err: %w", err))
	}
	return nil
}

 

posted @ 2024-04-28 19:48  划水的猫  阅读(218)  评论(0)  编辑  收藏  举报