在Golang中使用MongoDB的一对多关系与文档引用(附实例)

341 阅读3分钟

在这个例子中,我们将在“博客”领域中处理“标签(tags)”和“帖子(posts)”。通常你可以把标签数据直接作为数组字段内嵌存储在帖子文档中,但这很快会导致标签数据重复。为了避免这种情况,可以只把标签的引用(本例中是标签的 uuid)以数组字段的形式存储在帖子文档中。由于每个帖子的标签数量不多、增长有限,把标签引用存放在帖子文档内是可行的。这是一个典型的一对多(1-n)关系,参见下面的例子。

数据库内容

准备工作

// Select the "blog" database for the commands below.
use blog

// Tags collection: both "uuid" and "name" must be unique across tags.
db.createCollection("tags")
db.tags.createIndex({"uuid":1},{unique:true,name:"UQ_uuid"})
db.tags.createIndex({"name":1},{unique:true,name:"UQ_name"})

// Posts collection: "uuid" must be unique across posts.
db.createCollection("posts")
db.posts.createIndex({"uuid":1},{unique:true,name:"UQ_uuid"})

标签数据

[
  {
    "_id": {
      "$oid": "604f463ed42b77e01e7f03db"
    },
    "uuid": "07131782-0090-40cf-b530-68e774669826",
    "name": "tech",
    "description": "Technology related stuff"
  },
  {
    "_id": {
      "$oid": "604f463ed42b77e01e7f03dc"
    },
    "uuid": "a1e57bc6-de13-432d-ad22-7a3f5ba07b80",
    "name": "sport",
    "description": "Any kind of sport"
  },
  {
    "_id": {
      "$oid": "604f463ed42b77e01e7f03dd"
    },
    "uuid": "4f489d03-e4c9-444a-ba09-076d0584c96e",
    "name": "travel",
    "description": "Travel or holiday related stuff"
  }
]

帖子数据

[
  {
    "_id": {
      "$oid": "604f4831601401a8de9dfe1a"
    },
    "uuid": "acda18b2-f5c9-4736-8201-d1dc68c59354",
    "subject": "Tech in football",
    "text": "VAR ruined football so far!",
    "created_at": {
      "$date": "2021-03-15T11:42:41.171Z"
    },
    "tags": [
      "07131782-0090-40cf-b530-68e774669826",
      "a1e57bc6-de13-432d-ad22-7a3f5ba07b80"
    ]
  },
  {
    "_id": {
      "$oid": "604f4891601401a8de9dfe1b"
    },
    "uuid": "368f4fba-b1d1-4c58-9682-74ceedd276e9",
    "subject": "Covid on holiday",
    "text": "Do not take your Covid on holiday with you :)",
    "created_at": {
      "$date": "2021-03-15T11:44:17.613Z"
    },
    "tags": [
      "4f489d03-e4c9-444a-ba09-076d0584c96e"
    ]
  },
  {
    "_id": {
      "$oid": "604f494e601401a8de9dfe1c"
    },
    "uuid": "ba87a063-7379-4b13-b7a5-d986f2bed43c",
    "subject": "Pasta or not to pasta?",
    "text": "Let's talk about what pasta is easiest to make.",
    "created_at": {
      "$date": "2021-03-15T11:47:26.607Z"
    },
    "tags": null
  }
]

储存

模型

package storage

import (
	"context"
	"time"
)

// PostStorer describes the persistence operations available for posts.
type PostStorer interface {
	// Insert stores the given post.
	Insert(ctx context.Context, post PostWrite) error
	// Find returns the post identified by uuid.
	Find(ctx context.Context, uuid string) (PostRead, error)
}

// PostWrite is the post shape used when inserting a post. Its "tags" key
// carries tag UUID references (TagUUIDs), not embedded tag documents.
type PostWrite struct {
	UUID      string    `bson:"uuid"`
	Subject   string    `bson:"subject"`
	Text      string    `bson:"text"`
	CreatedAt time.Time `bson:"created_at"`
	TagUUIDs  []string  `bson:"tags"`
}

// PostRead is the post shape returned when reading a post. Unlike PostWrite,
// its "tags" key holds full Tag values instead of UUID references.
type PostRead struct {
	ID        string    `bson:"_id"`
	UUID      string    `bson:"uuid"`
	Subject   string    `bson:"subject"`
	Text      string    `bson:"text"`
	CreatedAt time.Time `bson:"created_at"`
	Tags      []Tag     `bson:"tags"`
}

// Tag is a single tag as it appears in PostRead.Tags, decoded from the
// "_id", "uuid", "name" and "description" keys.
type Tag struct {
	ID          string `bson:"_id"`
	UUID        string `bson:"uuid"`
	Name        string `bson:"name"`
	Description string `bson:"description"`
}

存储器

package mongodb

import (
	"context"
	"log"
	"time"

	"github.com/you/mongo/internal/pkg/domain"
	"github.com/you/mongo/internal/pkg/storage"
	"go.mongodb.org/mongo-driver/bson"
	"go.mongodb.org/mongo-driver/mongo"
)

// Compile-time check that PostStorage satisfies storage.PostStorer.
var _ storage.PostStorer = PostStorage{}

// PostStorage is a storage.PostStorer backed by a MongoDB database.
//
// Database is the handle whose "posts" collection is read and written.
// Timeout is the duration passed to context.WithTimeout for each Insert and
// Find call.
type PostStorage struct {
	Database *mongo.Database
	Timeout  time.Duration
}

// Insert stores post in the "posts" collection, giving up after p.Timeout.
//
// It returns domain.ErrConflict when the server reports a duplicate key
// (write error code 11000) and domain.ErrInternal for any other failure.
// The underlying error is logged before it is translated.
func (p PostStorage) Insert(ctx context.Context, post storage.PostWrite) error {
	ctx, cancel := context.WithTimeout(ctx, p.Timeout)
	defer cancel()

	_, err := p.Database.Collection("posts").InsertOne(ctx, post)
	if err == nil {
		return nil
	}

	log.Println(err)

	// A WriteException is not guaranteed to carry any write errors (it may
	// only report a write concern error), so iterate instead of indexing
	// WriteErrors[0], which would panic on an empty slice.
	if we, ok := err.(mongo.WriteException); ok {
		for _, e := range we.WriteErrors {
			if e.Code == 11000 {
				return domain.ErrConflict
			}
		}
	}

	return domain.ErrInternal
}

// Find returns the post identified by uuid with its tag references resolved
// to full tag documents, sorted by tag name in ascending order. The whole
// operation, including reading the result, gives up after p.Timeout.
//
// It returns domain.ErrNotFound when no post matches and domain.ErrInternal
// when running the aggregation or decoding its result fails. The underlying
// error is logged before it is translated.
func (p PostStorage) Find(ctx context.Context, uuid string) (storage.PostRead, error) {
	ctx, cancel := context.WithTimeout(ctx, p.Timeout)
	defer cancel()

	// Simpler alternative kept for reference. Tags are not sorted here!
	// qry := []bson.M{
	// 	{
	// 		"$match": bson.M{
	// 			"uuid": uuid,
	// 		},
	// 	},
	// 	{
	// 		"$lookup": bson.M{
	// 			"from":         "tags", // Child collection to join
	// 			"localField":   "tags", // Parent collection reference holding field
	// 			"foreignField": "uuid", // Child collection reference field
	// 			"as":           "tags", // Arbitrary field name to store result set
	// 		},
	// 	},
	// }

	// Tags are asc sorted here!
	qry := []bson.M{
		{
			"$match": bson.M{
				"uuid": uuid,
			},
		},
		{
			"$lookup": bson.M{
				// Define the tags collection for the join.
				"from": "tags",
				// Specify the variable to use in the pipeline stage. A post
				// may store "tags" as null (or omit it); $in below requires
				// an array, so fall back to an empty one in that case.
				"let": bson.M{
					"tags": bson.M{
						"$ifNull": []interface{}{
							"$tags",
							[]interface{}{},
						},
					},
				},
				"pipeline": []bson.M{
					// Select only the relevant tags from the tags collection.
					// Otherwise all the tags are selected.
					{
						"$match": bson.M{
							"$expr": bson.M{
								"$in": []interface{}{
									"$uuid",
									"$$tags",
								},
							},
						},
					},
					// Sort tags by their name field in asc. -1 = desc
					{
						"$sort": bson.M{
							"name": 1,
						},
					},
				},
				// Use tags as the field name to match struct field.
				"as": "tags",
			},
		},
	}

	cur, err := p.Database.Collection("posts").Aggregate(ctx, qry)
	if err != nil {
		log.Println(err)

		return storage.PostRead{}, domain.ErrInternal
	}
	// Registered right after the cursor exists so that every return path
	// below releases it.
	defer cur.Close(context.Background())

	var pos []storage.PostRead

	// Use the timeout-bound ctx so reading the result cannot outlive
	// p.Timeout.
	if err := cur.All(ctx, &pos); err != nil {
		log.Println(err)

		return storage.PostRead{}, domain.ErrInternal
	}

	if err := cur.Err(); err != nil {
		log.Println(err)

		return storage.PostRead{}, domain.ErrInternal
	}

	if len(pos) == 0 {
		return storage.PostRead{}, domain.ErrNotFound
	}

	return pos[0], nil
}

HTTP路由器

模型

package post

import "time"

// Request

// Create is the JSON request body accepted when creating a post. Tags is
// passed on to storage as the post's list of tag UUIDs.
type Create struct {
	Subject string   `json:"subject"`
	Text    string   `json:"text"`
	Tags    []string `json:"tags"`
}

// Response

// Response is the JSON representation of a post returned to the client,
// with its tags included as full Tag objects.
type Response struct {
	ID        string    `json:"id"`
	UUID      string    `json:"uuid"`
	Subject   string    `json:"subject"`
	Text      string    `json:"text"`
	CreatedAt time.Time `json:"created_at"`
	Tags      []Tag     `json:"tags"`
}

// Tag is the JSON representation of a single tag inside Response.Tags.
type Tag struct {
	ID          string `json:"id"`
	UUID        string `json:"uuid"`
	Name        string `json:"name"`
	Description string `json:"description"`
}

控制器

如你所见,这个文件承担了很多职责,并且没有真正的请求校验。你应该根据自己的需要对它进行重构。

package post

import (
	"encoding/json"
	"fmt"
	"net/http"
	"time"

	"github.com/you/mongo/internal/pkg/domain"
	"github.com/you/mongo/internal/pkg/storage"
	"github.com/google/uuid"
	"github.com/julienschmidt/httprouter"
)

// Controller holds the HTTP handlers for posts. Storage is the post store
// the handlers read from and write to.
type Controller struct {
	Storage storage.PostStorer
}

// Create handles POST /api/v1/posts.
//
// It decodes the JSON request body, stores a new post under a freshly
// generated UUID and responds with 201 Created and that UUID as the body.
// A body that cannot be decoded yields 400 Bad Request, a storage conflict
// yields 409 Conflict and any other storage failure yields 500 Internal
// Server Error.
func (c Controller) Create(w http.ResponseWriter, r *http.Request) {
	var req Create
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		// An undecodable body is the client's fault, not a server failure.
		w.WriteHeader(http.StatusBadRequest)
		return
	}

	id := uuid.New().String()

	// The current nanosecond value is appended to the submitted subject and
	// text. NOTE(review): presumably to produce distinct demo data — confirm.
	err := c.Storage.Insert(r.Context(), storage.PostWrite{
		UUID:      id,
		Subject:   fmt.Sprintf("%s - %d", req.Subject, time.Now().UTC().Nanosecond()),
		Text:      fmt.Sprintf("%s - %d", req.Text, time.Now().UTC().Nanosecond()),
		CreatedAt: time.Now(),
		TagUUIDs:  req.Tags,
	})
	if err != nil {
		switch err {
		case domain.ErrConflict:
			w.WriteHeader(http.StatusConflict)
		default:
			w.WriteHeader(http.StatusInternalServerError)
		}
		return
	}

	w.WriteHeader(http.StatusCreated)
	// The write error is ignored: the status line is already sent and there
	// is nothing further to report to the client.
	_, _ = w.Write([]byte(id))
}

// Find handles GET /api/v1/posts/:uuid.
//
// It looks up the post named by the "uuid" route parameter and writes it as
// a JSON document. A missing post yields 404 Not Found; any other storage
// failure, or a failure to encode the response, yields 500 Internal Server
// Error.
func (c Controller) Find(w http.ResponseWriter, r *http.Request) {
	postUUID := httprouter.ParamsFromContext(r.Context()).ByName("uuid")

	found, err := c.Storage.Find(r.Context(), postUUID)
	if err == domain.ErrNotFound {
		w.WriteHeader(http.StatusNotFound)
		return
	}
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		return
	}

	// Map storage tags to their response counterparts. The slice is always
	// non-nil so that a post without tags still encodes as an empty array.
	tags := make([]Tag, 0, len(found.Tags))
	for _, t := range found.Tags {
		tags = append(tags, Tag{
			ID:          t.ID,
			UUID:        t.UUID,
			Name:        t.Name,
			Description: t.Description,
		})
	}

	payload, err := json.Marshal(Response{
		ID:        found.ID,
		UUID:      found.UUID,
		Subject:   found.Subject,
		Text:      found.Text,
		CreatedAt: found.CreatedAt,
		Tags:      tags,
	})
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json; charset=utf-8")
	_, _ = w.Write(payload)
}

测试

查找

curl --request GET 'http://localhost:3000/api/v1/posts/acda18b2-f5c9-4736-8201-d1dc68c59354'

{
  "id": "604f4831601401a8de9dfe1a",
  "uuid": "acda18b2-f5c9-4736-8201-d1dc68c59354",
  "subject": "Tech in football",
  "text": "VAR ruined football so far!",
  "created_at": "2021-03-15T11:42:41.171Z",
  "tags": [
    {
      "id": "604f463ed42b77e01e7f03dc",
      "uuid": "a1e57bc6-de13-432d-ad22-7a3f5ba07b80",
      "name": "sport",
      "description": "Any kind of sport"
    },
    {
      "id": "604f463ed42b77e01e7f03db",
      "uuid": "07131782-0090-40cf-b530-68e774669826",
      "name": "tech",
      "description": "Technology related stuff"
    }
  ]
}

创建

curl --request POST 'http://localhost:3000/api/v1/posts' \
     --header 'Content-Type: application/json' \
     --data-raw '{
        "subject": "Subject 1",
        "text": "Text 1",
        "tags": [
            "07131782-0090-40cf-b530-68e774669826"
        ]
    }'

49329796-9f04-44d1-99fb-49128bba7a9c