通过AWS STS临时授权凭证分片上传文件
一、相关文档
1.AWS分片上传文档:https://docs.aws.amazon.com/zh_cn/AmazonS3/latest/userguide/mpuoverview.html
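2. 获取 AWS STS 临时授权凭证(AssumeRole API 文档):https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRole.html;Go 示例见下文第二节。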
二、Go 示例
package main import ( "context" "errors" "fmt" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/aws/aws-sdk-go-v2/service/s3/types" "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/cubewise-code/go-mime" "io" "os" "path/filepath" "strconv" "strings" "time" ) var fileChunkSize int64 = 1024 * 1024 * 100 // 建议设置为50-100MB func main() { var bucketName = "{bucket名称}" var authPaths = []string{"uid", "output"} //uid可替换用户真实uid对应的目录,可支持多层级目录 var expire int32 = 3600 //STS token有效期 cfg := &StoreClientConf{ RoleArn: "{roleArn}", Region: "{bucket region}", AccessKeyID: "{bucket ak}", AccessKeySecret: "{bucket sk}", } client := NewAwsClient(cfg) // 1.获取STS授权凭证 stsInfo, err := client.GetStsCredentials(context.Background(), bucketName, authPaths, expire) if err != nil { fmt.Println("client.GetStsCredentials err: " + err.Error()) return } fmt.Println("sts ak: " + stsInfo.AccessKeyId) fmt.Println("sts sk: " + stsInfo.AccessSecret) fmt.Println("sts token: " + stsInfo.SecurityToken) // 2.通过STS token上传文件 var objectKey = "output/hls/20240425-people_low.m3u8" //S3 bucket objectKey var fileDir = "/vagrant/20240425-people_low.m3u8" //本地文件目录 err = client.UploadByToken(context.Background(), stsInfo, fileDir, bucketName, objectKey) if err != nil { fmt.Println("client.UploadByToken err: " + err.Error()) } fmt.Println("client.UploadByToken success") } type AwsClient struct { roleArn string region string accessKeyID string accessKeySecret string } type StoreClientConf struct { RoleArn string Region string AccessKeyID string AccessKeySecret string } type StsCredentials struct { AccessKeyId string AccessSecret string SecurityToken string ExpireTime int64 } func NewAwsClient(cfg *StoreClientConf) *AwsClient { return &AwsClient{ roleArn: cfg.RoleArn, region: cfg.Region, accessKeyID: cfg.AccessKeyID, accessKeySecret: cfg.AccessKeySecret, } } func (s *AwsClient) loadConfig(ctx 
context.Context) (aws.Config, error) { cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(s.region), config.WithCredentialsProvider(credentials.StaticCredentialsProvider{ Value: aws.Credentials{ AccessKeyID: s.accessKeyID, SecretAccessKey: s.accessKeySecret, SessionToken: "", Source: "", }, }), ) if err != nil { fmt.Println("awsClient LoadDefaultConfig err:" + err.Error()) return aws.Config{}, errors.New("awsClient LoadDefaultConfig err") } return cfg, nil } func (s *AwsClient) loadConfigByToken(ctx context.Context, ak, sk, token string) (aws.Config, error) { cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(s.region), config.WithCredentialsProvider(credentials.StaticCredentialsProvider{ Value: aws.Credentials{ AccessKeyID: ak, SecretAccessKey: sk, SessionToken: token, Source: "", }, }), ) if err != nil { fmt.Println("load config error: " + err.Error()) return aws.Config{}, errors.New("load config error") } return cfg, nil } // 定义自己想要的policy func (s *AwsClient) authPolicy(ctx context.Context, bucket string, authPaths []string) string { var resource []string for _, v := range authPaths { path := strings.TrimRight(v, "/") //去除最后一个/ resource = append(resource, `"arn:aws:s3:::` + bucket+ `/` + path + `/*"`) } policy := `{ "Version": "2012-10-17", "Statement": [ { "Action": [ "s3:GetObject", "s3:GetObjectAttributes", "s3:GetObjectTagging", "s3:PutObject", "s3:PutObjectTagging", "s3:UploadPart" ], "Effect": "Allow", "Resource": [` + strings.Join(resource, ",") + `] } ] }` return policy } func (s *AwsClient) GetStsCredentials(ctx context.Context, bucket string, authPaths []string, expired int32) (*StsCredentials, error) { // 1.拼装授权策略 policy := s.authPolicy(ctx, bucket, authPaths) // 2.初始化client cfg, err := s.loadConfig(ctx) if err != nil { return nil, err } client := sts.NewFromConfig(cfg) // 3.调用s3接口,获取sts token roleSessionName := "s3bucket" + strconv.FormatInt(time.Now().Unix(), 10) //需要按用户的维度去修改 input := &sts.AssumeRoleInput{ RoleArn: &s.roleArn, 
RoleSessionName: &roleSessionName, DurationSeconds: &expired, Policy: aws.String(policy), } resp, err := client.AssumeRole(ctx, input) if err != nil { fmt.Println("GetStsCredentials client.AssumeRole err:" + err.Error()) return nil, err } if resp == nil { fmt.Println("GetStsCredentials response is nil") return nil, err } var expire int64 if resp.Credentials != nil && resp.Credentials.Expiration != nil { expire = resp.Credentials.Expiration.Unix() } return &StsCredentials{ AccessKeyId: *resp.Credentials.AccessKeyId, AccessSecret: *resp.Credentials.SecretAccessKey, SecurityToken: *resp.Credentials.SessionToken, ExpireTime: expire, }, nil } type FileChunk struct { Number int // Chunk number Offset int64 // Chunk offset Size int64 // Chunk size. } type S3Part struct { Number int ETag *string } // 按大小分片 func SplitFileByPartSize(fileName string, chunkSize int64) ([]FileChunk, error) { if chunkSize <= 0 { return nil, errors.New("chunkSize invalid") } file, err := os.Open(fileName) if err != nil { return nil, err } defer file.Close() stat, err := file.Stat() if err != nil { return nil, err } var chunkN = stat.Size() / chunkSize if chunkN >= 10000 { return nil, errors.New("Too many parts, please increase part size") } var chunks []FileChunk var chunk = FileChunk{} for i := int64(0); i < chunkN; i++ { chunk.Number = int(i + 1) chunk.Offset = i * chunkSize chunk.Size = chunkSize chunks = append(chunks, chunk) } if stat.Size()%chunkSize > 0 { chunk.Number = len(chunks) + 1 chunk.Offset = int64(len(chunks)) * chunkSize chunk.Size = stat.Size() % chunkSize chunks = append(chunks, chunk) } return chunks, nil } func (s *AwsClient) UploadByToken(ctx context.Context, stsInfo *StsCredentials, filePath, bucketName, objectKey string) error { // 1.初始化客户端 cfg, err := s.loadConfigByToken(ctx, stsInfo.AccessKeyId, stsInfo.AccessSecret, stsInfo.SecurityToken) if err != nil { return err } s3Client := s3.NewFromConfig(cfg) // 2.计算文件分片 chunks, err := SplitFileByPartSize(filePath, 
fileChunkSize) for _, v := range chunks { fmt.Println("chunk: ", v.Number, ", size:", v.Size, ", offset:", v.Offset) } // 3.获取文件操作句柄 fd, err := os.Open(filePath) if err != nil { return err } defer fd.Close() contentType := gomime.TypeByExtension(filepath.Ext(fd.Name())) // 4.上传文件 if len(chunks) <= 1 { // 单片,普通上传 _, err = s3Client.PutObject(ctx, &s3.PutObjectInput{ Bucket: aws.String(bucketName), Key: aws.String(objectKey), Body: fd, ContentType: aws.String(contentType), }) if err != nil { fmt.Println("s3Client.PutObject err:", err.Error()) return err } return nil } // 4.1创建分段上传 imur, err := s3Client.CreateMultipartUpload(ctx, &s3.CreateMultipartUploadInput{Bucket: aws.String(bucketName), Key:aws.String(objectKey), ContentType: aws.String(contentType)}) if err != nil { fmt.Println("s3Client.CreateMultipartUpload err:", err.Error()) return err } // 4.2分段上传 var parts []*S3Part for _, chunk := range chunks { n, err := fd.Seek(chunk.Offset, os.SEEK_SET) fmt.Println("n:", n, ", err: ", err) input := &s3.UploadPartInput{ Bucket: aws.String(bucketName), Key: aws.String(objectKey), PartNumber: aws.Int32(int32(chunk.Number)), UploadId: imur.UploadId, Body: &io.LimitedReader{R: fd, N: chunk.Size}, //按片读取 ContentLength: aws.Int64(chunk.Size), } tmp, err := s3Client.UploadPart(ctx, input) if err != nil { fmt.Println("s3Client.UploadPart err:", err.Error()) // 终止分片上传 _, errs := s3Client.AbortMultipartUpload(ctx, &s3.AbortMultipartUploadInput{Bucket: aws.String(bucketName), Key: aws.String(objectKey), UploadId: imur.UploadId}) if errs != nil { fmt.Println("s3Client.AbortMultipartUpload err:", errs.Error()) } return err } parts = append(parts, &S3Part{Number: chunk.Number, ETag: tmp.ETag}) } // 4.3完成分段上传 var multipart []types.CompletedPart for _, v := range parts { multipart = append(multipart, types.CompletedPart{ETag: v.ETag, PartNumber: aws.Int32(int32(v.Number))}) } mult := &types.CompletedMultipartUpload{Parts: multipart} cInput := &s3.CompleteMultipartUploadInput{ Bucket: 
aws.String(bucketName), Key: aws.String(objectKey), UploadId: imur.UploadId, MultipartUpload: mult, } _, err = s3Client.CompleteMultipartUpload(ctx, cInput) if err != nil { fmt.Println("s3Client.CompleteMultipartUpload err:", err.Error()) return err } return nil }