diff --git a/graphql/documents/queries/settings/metadata.graphql b/graphql/documents/queries/settings/metadata.graphql index 9899f1dcd54..a9092b8ea6b 100644 --- a/graphql/documents/queries/settings/metadata.graphql +++ b/graphql/documents/queries/settings/metadata.graphql @@ -14,6 +14,10 @@ query MetadataGenerate($input: GenerateMetadataInput!) { metadataGenerate(input: $input) } +query MetadataAutoTag($input: AutoTagMetadataInput!) { + metadataAutoTag(input: $input) +} + query MetadataClean { metadataClean } diff --git a/graphql/schema/schema.graphql b/graphql/schema/schema.graphql index 69065813023..e9458425acb 100644 --- a/graphql/schema/schema.graphql +++ b/graphql/schema/schema.graphql @@ -75,6 +75,8 @@ type Query { metadataScan(input: ScanMetadataInput!): String! """Start generating content. Returns the job ID""" metadataGenerate(input: GenerateMetadataInput!): String! + """Start auto-tagging. Returns the job ID""" + metadataAutoTag(input: AutoTagMetadataInput!): String! """Clean metadata. Returns the job ID""" metadataClean: String! diff --git a/graphql/schema/types/metadata.graphql b/graphql/schema/types/metadata.graphql index 025d3e11f69..6af77486a99 100644 --- a/graphql/schema/types/metadata.graphql +++ b/graphql/schema/types/metadata.graphql @@ -9,6 +9,15 @@ input ScanMetadataInput { nameFromMetadata: Boolean! } +input AutoTagMetadataInput { + """IDs of performers to tag files with, or "*" for all""" + performers: [String!] + """IDs of studios to tag files with, or "*" for all""" + studios: [String!] + """IDs of tags to tag files with, or "*" for all""" + tags: [String!] +} + type MetadataUpdateStatus { progress: Float! status: String! diff --git a/pkg/api/resolver_query_metadata.go b/pkg/api/resolver_query_metadata.go index 8acb708aecf..3d54ee40f09 100644 --- a/pkg/api/resolver_query_metadata.go +++ b/pkg/api/resolver_query_metadata.go @@ -27,6 +27,11 @@ func (r *queryResolver) MetadataGenerate(ctx context.Context, input models.Gener return "todo", nil } +func (r *queryResolver) MetadataAutoTag(ctx context.Context, input models.AutoTagMetadataInput) (string, error) { + manager.GetInstance().AutoTag(input.Performers, input.Studios, input.Tags) + return "todo", nil +} + func (r *queryResolver) MetadataClean(ctx context.Context) (string, error) { manager.GetInstance().Clean() return "todo", nil diff --git a/pkg/manager/job_status.go b/pkg/manager/job_status.go index f412c8dc6b1..a1a57802e81 100644 --- a/pkg/manager/job_status.go +++ b/pkg/manager/job_status.go @@ -10,6 +10,7 @@ const ( Generate JobStatus = 4 Clean JobStatus = 5 Scrape JobStatus = 6 + AutoTag JobStatus = 7 ) func (s JobStatus) String() string { @@ -26,6 +27,8 @@ func (s JobStatus) String() string { statusMessage = "Scan" case Generate: statusMessage = "Generate" + case AutoTag: + statusMessage = "Auto Tag" } return statusMessage diff --git a/pkg/manager/manager_tasks.go b/pkg/manager/manager_tasks.go index 0597051a1a0..0b4483518c3 100644 --- a/pkg/manager/manager_tasks.go +++ b/pkg/manager/manager_tasks.go @@ -2,6 +2,7 @@ package manager import ( "path/filepath" + "strconv" "sync" "time" @@ -17,6 +18,8 @@ type TaskStatus struct { Progress float64 LastUpdate time.Time stopping bool + upTo int + total int } func (t *TaskStatus) Stop() bool { @@ -34,10 +37,16 @@ func (t *TaskStatus) setProgress(upTo int, total int) { if total == 0 { t.Progress = 1 } + t.upTo = upTo + t.total = total t.Progress = float64(upTo) / float64(total) t.updated() } +func (t *TaskStatus) incrementProgress() { + t.setProgress(t.upTo+1, t.total) +} + func (t *TaskStatus) indefiniteProgress() { t.Progress = -1 t.updated() @@ -202,6 +211,172 @@ func (s *singleton) Generate(sprites bool, previews bool, markers bool, transcod }() } +func (s *singleton) AutoTag(performerIds []string, studioIds []string, tagIds []string) { + if s.Status.Status != Idle { + return + } + s.Status.SetStatus(AutoTag) + s.Status.indefiniteProgress() + + go func() { + defer s.returnToIdleState() + + // calculate work load + performerCount := len(performerIds) + studioCount := len(studioIds) + tagCount := len(tagIds) + + performerQuery := models.NewPerformerQueryBuilder() + studioQuery := models.NewTagQueryBuilder() + tagQuery := models.NewTagQueryBuilder() + + const wildcard = "*" + var err error + if performerCount == 1 && performerIds[0] == wildcard { + performerCount, err = performerQuery.Count() + if err != nil { + logger.Errorf("Error getting performer count: %s", err.Error()) + } + } + if studioCount == 1 && studioIds[0] == wildcard { + studioCount, err = studioQuery.Count() + if err != nil { + logger.Errorf("Error getting studio count: %s", err.Error()) + } + } + if tagCount == 1 && tagIds[0] == wildcard { + tagCount, err = tagQuery.Count() + if err != nil { + logger.Errorf("Error getting tag count: %s", err.Error()) + } + } + + total := performerCount + studioCount + tagCount + s.Status.setProgress(0, total) + + s.autoTagPerformers(performerIds) + s.autoTagStudios(studioIds) + s.autoTagTags(tagIds) + }() +} + +func (s *singleton) autoTagPerformers(performerIds []string) { + performerQuery := models.NewPerformerQueryBuilder() + + var wg sync.WaitGroup + for _, performerId := range performerIds { + var performers []*models.Performer + if performerId == "*" { + var err error + performers, err = performerQuery.All() + if err != nil { + logger.Errorf("Error querying performers: %s", err.Error()) + continue + } + } else { + performerIdInt, err := strconv.Atoi(performerId) + if err != nil { + logger.Errorf("Error parsing performer id %s: %s", performerId, err.Error()) + continue + } + + performer, err := performerQuery.Find(performerIdInt) + if err != nil { + logger.Errorf("Error finding performer id %s: %s", performerId, err.Error()) + continue + } + performers = append(performers, performer) + } + + for _, performer := range performers { + wg.Add(1) + task := AutoTagPerformerTask{performer: performer} + go task.Start(&wg) + wg.Wait() + + s.Status.incrementProgress() + } + } +} + +func (s *singleton) autoTagStudios(studioIds []string) { + studioQuery := models.NewStudioQueryBuilder() + + var wg sync.WaitGroup + for _, studioId := range studioIds { + var studios []*models.Studio + if studioId == "*" { + var err error + studios, err = studioQuery.All() + if err != nil { + logger.Errorf("Error querying studios: %s", err.Error()) + continue + } + } else { + studioIdInt, err := strconv.Atoi(studioId) + if err != nil { + logger.Errorf("Error parsing studio id %s: %s", studioId, err.Error()) + continue + } + + studio, err := studioQuery.Find(studioIdInt, nil) + if err != nil { + logger.Errorf("Error finding studio id %s: %s", studioId, err.Error()) + continue + } + studios = append(studios, studio) + } + + for _, studio := range studios { + wg.Add(1) + task := AutoTagStudioTask{studio: studio} + go task.Start(&wg) + wg.Wait() + + s.Status.incrementProgress() + } + } +} + +func (s *singleton) autoTagTags(tagIds []string) { + tagQuery := models.NewTagQueryBuilder() + + var wg sync.WaitGroup + for _, tagId := range tagIds { + var tags []*models.Tag + if tagId == "*" { + var err error + tags, err = tagQuery.All() + if err != nil { + logger.Errorf("Error querying tags: %s", err.Error()) + continue + } + } else { + tagIdInt, err := strconv.Atoi(tagId) + if err != nil { + logger.Errorf("Error parsing tag id %s: %s", tagId, err.Error()) + continue + } + + tag, err := tagQuery.Find(tagIdInt, nil) + if err != nil { + logger.Errorf("Error finding tag id %s: %s", tagId, err.Error()) + continue + } + tags = append(tags, tag) + } + + for _, tag := range tags { + wg.Add(1) + task := AutoTagTagTask{tag: tag} + go task.Start(&wg) + wg.Wait() + + s.Status.incrementProgress() + } + } +} + func (s *singleton) Clean() { if s.Status.Status != Idle { return diff --git a/pkg/manager/task_autotag.go b/pkg/manager/task_autotag.go new file mode 100644 index 00000000000..4709c6f6ff8 --- /dev/null +++ b/pkg/manager/task_autotag.go @@ -0,0 +1,171 @@ +package manager + +import ( + "context" + "database/sql" + "strings" + "sync" + + "github.com/stashapp/stash/pkg/database" + "github.com/stashapp/stash/pkg/logger" + "github.com/stashapp/stash/pkg/models" +) + +type AutoTagPerformerTask struct { + performer *models.Performer +} + +func (t *AutoTagPerformerTask) Start(wg *sync.WaitGroup) { + defer wg.Done() + + t.autoTagPerformer() +} + +func getQueryRegex(name string) string { + const separatorChars = `.\-_ ` + // handle path separators + const endSeparatorChars = separatorChars + `\\/` + const separator = `[` + separatorChars + `]` + const endSeparator = `[` + endSeparatorChars + `]` + + ret := strings.Replace(name, " ", separator+"*", -1) + ret = "(?:^|" + endSeparator + "+)" + ret + "(?:$|" + endSeparator + "+)" + return ret +} + +func (t *AutoTagPerformerTask) autoTagPerformer() { + qb := models.NewSceneQueryBuilder() + jqb := models.NewJoinsQueryBuilder() + + regex := getQueryRegex(t.performer.Name.String) + + scenes, err := qb.QueryAllByPathRegex(regex) + + if err != nil { + logger.Infof("Error querying scenes with regex '%s': %s", regex, err.Error()) + return + } + + ctx := context.TODO() + tx := database.DB.MustBeginTx(ctx, nil) + + for _, scene := range scenes { + added, err := jqb.AddPerformerScene(scene.ID, t.performer.ID, tx) + + if err != nil { + logger.Infof("Error adding performer '%s' to scene '%s': %s", t.performer.Name.String, scene.GetTitle(), err.Error()) + tx.Rollback() + return + } + + if added { + logger.Infof("Added performer '%s' to scene '%s'", t.performer.Name.String, scene.GetTitle()) + } + } + + if err := tx.Commit(); err != nil { + logger.Infof("Error adding performer to scene: %s", err.Error()) + return + } +} + +type AutoTagStudioTask struct { + studio *models.Studio +} + +func (t *AutoTagStudioTask) Start(wg *sync.WaitGroup) { + defer wg.Done() + + t.autoTagStudio() +} + +func (t *AutoTagStudioTask) autoTagStudio() { + qb := models.NewSceneQueryBuilder() + + regex := getQueryRegex(t.studio.Name.String) + + scenes, err := qb.QueryAllByPathRegex(regex) + + if err != nil { + logger.Infof("Error querying scenes with regex '%s': %s", regex, err.Error()) + return + } + + ctx := context.TODO() + tx := database.DB.MustBeginTx(ctx, nil) + + for _, scene := range scenes { + if scene.StudioID.Int64 == int64(t.studio.ID) { + // don't modify + continue + } + + logger.Infof("Adding studio '%s' to scene '%s'", t.studio.Name.String, scene.GetTitle()) + + // set the studio id + studioID := sql.NullInt64{Int64: int64(t.studio.ID), Valid: true} + scenePartial := models.ScenePartial{ + ID: scene.ID, + StudioID: &studioID, + } + + _, err := qb.Update(scenePartial, tx) + + if err != nil { + logger.Infof("Error adding studio to scene: %s", err.Error()) + tx.Rollback() + return + } + } + + if err := tx.Commit(); err != nil { + logger.Infof("Error adding studio to scene: %s", err.Error()) + return + } +} + +type AutoTagTagTask struct { + tag *models.Tag +} + +func (t *AutoTagTagTask) Start(wg *sync.WaitGroup) { + defer wg.Done() + + t.autoTagTag() +} + +func (t *AutoTagTagTask) autoTagTag() { + qb := models.NewSceneQueryBuilder() + jqb := models.NewJoinsQueryBuilder() + + regex := getQueryRegex(t.tag.Name) + + scenes, err := qb.QueryAllByPathRegex(regex) + + if err != nil { + logger.Infof("Error querying scenes with regex '%s': %s", regex, err.Error()) + return + } + + ctx := context.TODO() + tx := database.DB.MustBeginTx(ctx, nil) + + for _, scene := range scenes { + added, err := jqb.AddSceneTag(scene.ID, t.tag.ID, tx) + + if err != nil { + logger.Infof("Error adding tag '%s' to scene '%s': %s", t.tag.Name, scene.GetTitle(), err.Error()) + tx.Rollback() + return + } + + if added { + logger.Infof("Added tag '%s' to scene '%s'", t.tag.Name, scene.GetTitle()) + } + } + + if err := tx.Commit(); err != nil { + logger.Infof("Error adding tag to scene: %s", err.Error()) + return + } +} diff --git a/pkg/manager/task_autotag_test.go b/pkg/manager/task_autotag_test.go new file mode 100644 index 00000000000..ab86058ce27 --- /dev/null +++ b/pkg/manager/task_autotag_test.go @@ -0,0 +1,339 @@ +// +build integration + +package manager + +import ( + "context" + "database/sql" + "fmt" + "io/ioutil" + "os" + "strings" + "sync" + "testing" + + "github.com/stashapp/stash/pkg/database" + "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/utils" + + _ "github.com/golang-migrate/migrate/v4/database/sqlite3" + _ "github.com/golang-migrate/migrate/v4/source/file" + "github.com/jmoiron/sqlx" +) + +const testName = "Foo Bar" +const testExtension = ".mp4" + +var testSeparators = []string{ + ".", + "-", + "_", + " ", +} + +func generateNamePatterns(name string, separator string) []string { + var ret []string + ret = append(ret, fmt.Sprintf("%s%saaa"+testExtension, name, separator)) + ret = append(ret, fmt.Sprintf("aaa%s%s"+testExtension, separator, name)) + ret = append(ret, fmt.Sprintf("aaa%s%s%sbbb"+testExtension, separator, name, separator)) + ret = append(ret, fmt.Sprintf("dir/%s%saaa"+testExtension, name, separator)) + ret = append(ret, fmt.Sprintf("dir\\%s%saaa"+testExtension, name, separator)) + ret = append(ret, fmt.Sprintf("%s%saaa/dir/bbb"+testExtension, name, separator)) + ret = append(ret, fmt.Sprintf("%s%saaa\\dir\\bbb"+testExtension, name, separator)) + ret = append(ret, fmt.Sprintf("dir/%s%s/aaa"+testExtension, name, separator)) + ret = append(ret, fmt.Sprintf("dir\\%s%s\\aaa"+testExtension, name, separator)) + + return ret +} + +func generateFalseNamePattern(name string, separator string) string { + splitted := strings.Split(name, " ") + + return fmt.Sprintf("%s%saaa%s%s"+testExtension, splitted[0], separator, separator, splitted[1]) +} + +func testTeardown(databaseFile string) { + err := database.DB.Close() + + if err != nil { + panic(err) + } + + err = os.Remove(databaseFile) + if err != nil { + panic(err) + } +} + +func runTests(m *testing.M) int { + // create the database file + f, err := ioutil.TempFile("", "*.sqlite") + if err != nil { + panic(fmt.Sprintf("Could not create temporary file: %s", err.Error())) + } + + f.Close() + databaseFile := f.Name() + database.Initialize(databaseFile) + + // defer close and delete the database + defer testTeardown(databaseFile) + + err = populateDB() + if err != nil { + panic(fmt.Sprintf("Could not populate database: %s", err.Error())) + } else { + // run the tests + return m.Run() + } +} + +func TestMain(m *testing.M) { + ret := runTests(m) + os.Exit(ret) +} + +func createPerformer(tx *sqlx.Tx) error { + // create the performer + pqb := models.NewPerformerQueryBuilder() + + performer := models.Performer{ + Image: []byte{0, 1, 2}, + Checksum: testName, + Name: sql.NullString{Valid: true, String: testName}, + Favorite: sql.NullBool{Valid: true, Bool: false}, + } + + _, err := pqb.Create(performer, tx) + if err != nil { + return err + } + + return nil +} + +func createStudio(tx *sqlx.Tx) error { + // create the studio + qb := models.NewStudioQueryBuilder() + + studio := models.Studio{ + Image: []byte{0, 1, 2}, + Checksum: testName, + Name: sql.NullString{Valid: true, String: testName}, + } + + _, err := qb.Create(studio, tx) + if err != nil { + return err + } + + return nil +} + +func createTag(tx *sqlx.Tx) error { + // create the studio + qb := models.NewTagQueryBuilder() + + tag := models.Tag{ + Name: testName, + } + + _, err := qb.Create(tag, tx) + if err != nil { + return err + } + + return nil +} + +func createScenes(tx *sqlx.Tx) error { + sqb := models.NewSceneQueryBuilder() + + // create the scenes + var scenePatterns []string + var falseScenePatterns []string + for _, separator := range testSeparators { + scenePatterns = append(scenePatterns, generateNamePatterns(testName, separator)...) + scenePatterns = append(scenePatterns, generateNamePatterns(strings.ToLower(testName), separator)...) + if separator != " " { + scenePatterns = append(scenePatterns, generateNamePatterns(strings.Replace(testName, " ", separator, -1), separator)...) + } + falseScenePatterns = append(falseScenePatterns, generateFalseNamePattern(testName, separator)) + } + + for _, fn := range scenePatterns { + err := createScene(sqb, tx, fn, true) + if err != nil { + return err + } + } + for _, fn := range falseScenePatterns { + err := createScene(sqb, tx, fn, false) + if err != nil { + return err + } + } + + return nil +} + +func createScene(sqb models.SceneQueryBuilder, tx *sqlx.Tx, name string, expectedResult bool) error { + scene := models.Scene{ + Checksum: utils.MD5FromString(name), + Path: name, + } + + // if expectedResult is true then we expect it to match, set the title accordingly + if expectedResult { + scene.Title = sql.NullString{Valid: true, String: name} + } + + _, err := sqb.Create(scene, tx) + + if err != nil { + return fmt.Errorf("Failed to create scene with name '%s': %s", name, err.Error()) + } + + return nil +} + +func populateDB() error { + ctx := context.TODO() + tx := database.DB.MustBeginTx(ctx, nil) + + err := createPerformer(tx) + if err != nil { + return err + } + + err = createStudio(tx) + if err != nil { + return err + } + + err = createTag(tx) + if err != nil { + return err + } + + err = createScenes(tx) + if err != nil { + return err + } + + if err := tx.Commit(); err != nil { + return err + } + + return nil +} + +func TestParsePerformers(t *testing.T) { + pqb := models.NewPerformerQueryBuilder() + performers, err := pqb.All() + + if err != nil { + t.Errorf("Error getting performer: %s", err) + return + } + + task := AutoTagPerformerTask{ + performer: performers[0], + } + + var wg sync.WaitGroup + wg.Add(1) + task.Start(&wg) + + // verify that scenes were tagged correctly + sqb := models.NewSceneQueryBuilder() + + scenes, err := sqb.All() + + for _, scene := range scenes { + performers, err := pqb.FindBySceneID(scene.ID, nil) + + if err != nil { + t.Errorf("Error getting scene performers: %s", err.Error()) + return + } + + // title is only set on scenes where we expect performer to be set + if scene.Title.String == scene.Path && len(performers) == 0 { + t.Errorf("Did not set performer '%s' for path '%s'", testName, scene.Path) + } else if scene.Title.String != scene.Path && len(performers) > 0 { + t.Errorf("Incorrectly set performer '%s' for path '%s'", testName, scene.Path) + } + } +} + +func TestParseStudios(t *testing.T) { + studioQuery := models.NewStudioQueryBuilder() + studios, err := studioQuery.All() + + if err != nil { + t.Errorf("Error getting studio: %s", err) + return + } + + task := AutoTagStudioTask{ + studio: studios[0], + } + + var wg sync.WaitGroup + wg.Add(1) + task.Start(&wg) + + // verify that scenes were tagged correctly + sqb := models.NewSceneQueryBuilder() + + scenes, err := sqb.All() + + for _, scene := range scenes { + // title is only set on scenes where we expect studio to be set + if scene.Title.String == scene.Path && scene.StudioID.Int64 != int64(studios[0].ID) { + t.Errorf("Did not set studio '%s' for path '%s'", testName, scene.Path) + } else if scene.Title.String != scene.Path && scene.StudioID.Int64 == int64(studios[0].ID) { + t.Errorf("Incorrectly set studio '%s' for path '%s'", testName, scene.Path) + } + } +} + +func TestParseTags(t *testing.T) { + tagQuery := models.NewTagQueryBuilder() + tags, err := tagQuery.All() + + if err != nil { + t.Errorf("Error getting performer: %s", err) + return + } + + task := AutoTagTagTask{ + tag: tags[0], + } + + var wg sync.WaitGroup + wg.Add(1) + task.Start(&wg) + + // verify that scenes were tagged correctly + sqb := models.NewSceneQueryBuilder() + + scenes, err := sqb.All() + + for _, scene := range scenes { + tags, err := tagQuery.FindBySceneID(scene.ID, nil) + + if err != nil { + t.Errorf("Error getting scene tags: %s", err.Error()) + return + } + + // title is only set on scenes where we expect performer to be set + if scene.Title.String == scene.Path && len(tags) == 0 { + t.Errorf("Did not set tag '%s' for path '%s'", testName, scene.Path) + } else if scene.Title.String != scene.Path && len(tags) > 0 { + t.Errorf("Incorrectly set tag '%s' for path '%s'", testName, scene.Path) + } + } +} diff --git a/pkg/models/model_scene.go b/pkg/models/model_scene.go index 2f487488f8e..8097bbc3fa7 100644 --- a/pkg/models/model_scene.go +++ b/pkg/models/model_scene.go @@ -2,6 +2,7 @@ package models import ( "database/sql" + "path/filepath" ) type Scene struct { @@ -27,7 +28,7 @@ type Scene struct { } type ScenePartial struct { - ID int `db:"id" json:"id"` + ID int `db:"id" json:"id"` Checksum *string `db:"checksum" json:"checksum"` Path *string `db:"path" json:"path"` Title *sql.NullString `db:"title" json:"title"` @@ -47,3 +48,11 @@ type ScenePartial struct { CreatedAt *SQLiteTimestamp `db:"created_at" json:"created_at"` UpdatedAt *SQLiteTimestamp `db:"updated_at" json:"updated_at"` } + +func (s Scene) GetTitle() string { + if s.Title.String != "" { + return s.Title.String + } + + return filepath.Base(s.Path) +} diff --git a/pkg/models/querybuilder_joins.go b/pkg/models/querybuilder_joins.go index a16961ca780..310bc8dad8f 100644 --- a/pkg/models/querybuilder_joins.go +++ b/pkg/models/querybuilder_joins.go @@ -1,6 +1,11 @@ package models -import "github.com/jmoiron/sqlx" +import ( + "database/sql" + + "github.com/jmoiron/sqlx" + "github.com/stashapp/stash/pkg/database" +) type JoinsQueryBuilder struct{} @@ -8,6 +13,41 @@ func NewJoinsQueryBuilder() JoinsQueryBuilder { return JoinsQueryBuilder{} } +func (qb *JoinsQueryBuilder) GetScenePerformers(sceneID int, tx *sqlx.Tx) ([]PerformersScenes, error) { + ensureTx(tx) + + // Delete the existing joins and then create new ones + query := `SELECT * from performers_scenes WHERE scene_id = ?` + + var rows *sqlx.Rows + var err error + if tx != nil { + rows, err = tx.Queryx(query, sceneID) + } else { + rows, err = database.DB.Queryx(query, sceneID) + } + + if err != nil && err != sql.ErrNoRows { + return nil, err + } + defer rows.Close() + + performerScenes := make([]PerformersScenes, 0) + for rows.Next() { + performerScene := PerformersScenes{} + if err := rows.StructScan(&performerScene); err != nil { + return nil, err + } + performerScenes = append(performerScenes, performerScene) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return performerScenes, nil +} + func (qb *JoinsQueryBuilder) CreatePerformersScenes(newJoins []PerformersScenes, tx *sqlx.Tx) error { ensureTx(tx) for _, join := range newJoins { @@ -22,6 +62,36 @@ func (qb *JoinsQueryBuilder) CreatePerformersScenes(newJoins []PerformersScenes, return nil } +// AddPerformerScene adds a performer to a scene. It does not make any change +// if the performer already exists on the scene. It returns true if scene +// performer was added. +func (qb *JoinsQueryBuilder) AddPerformerScene(sceneID int, performerID int, tx *sqlx.Tx) (bool, error) { + ensureTx(tx) + + existingPerformers, err := qb.GetScenePerformers(sceneID, tx) + + if err != nil { + return false, err + } + + // ensure not already present + for _, p := range existingPerformers { + if p.PerformerID == performerID && p.SceneID == sceneID { + return false, nil + } + } + + performerJoin := PerformersScenes{ + PerformerID: performerID, + SceneID: sceneID, + } + performerJoins := append(existingPerformers, performerJoin) + + err = qb.UpdatePerformersScenes(sceneID, performerJoins, tx) + + return err == nil, err +} + func (qb *JoinsQueryBuilder) UpdatePerformersScenes(sceneID int, updatedJoins []PerformersScenes, tx *sqlx.Tx) error { ensureTx(tx) @@ -41,6 +111,41 @@ func (qb *JoinsQueryBuilder) DestroyPerformersScenes(sceneID int, tx *sqlx.Tx) e return err } +func (qb *JoinsQueryBuilder) GetSceneTags(sceneID int, tx *sqlx.Tx) ([]ScenesTags, error) { + ensureTx(tx) + + // Delete the existing joins and then create new ones + query := `SELECT * from scenes_tags WHERE scene_id = ?` + + var rows *sqlx.Rows + var err error + if tx != nil { + rows, err = tx.Queryx(query, sceneID) + } else { + rows, err = database.DB.Queryx(query, sceneID) + } + + if err != nil && err != sql.ErrNoRows { + return nil, err + } + defer rows.Close() + + sceneTags := make([]ScenesTags, 0) + for rows.Next() { + sceneTag := ScenesTags{} + if err := rows.StructScan(&sceneTag); err != nil { + return nil, err + } + sceneTags = append(sceneTags, sceneTag) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return sceneTags, nil +} + func (qb *JoinsQueryBuilder) CreateScenesTags(newJoins []ScenesTags, tx *sqlx.Tx) error { ensureTx(tx) for _, join := range newJoins { @@ -66,6 +171,35 @@ func (qb *JoinsQueryBuilder) UpdateScenesTags(sceneID int, updatedJoins []Scenes return qb.CreateScenesTags(updatedJoins, tx) } +// AddSceneTag adds a tag to a scene. It does not make any change if the tag +// already exists on the scene. It returns true if scene tag was added. +func (qb *JoinsQueryBuilder) AddSceneTag(sceneID int, tagID int, tx *sqlx.Tx) (bool, error) { + ensureTx(tx) + + existingTags, err := qb.GetSceneTags(sceneID, tx) + + if err != nil { + return false, err + } + + // ensure not already present + for _, p := range existingTags { + if p.TagID == tagID && p.SceneID == sceneID { + return false, nil + } + } + + tagJoin := ScenesTags{ + TagID: tagID, + SceneID: sceneID, + } + tagJoins := append(existingTags, tagJoin) + + err = qb.UpdateScenesTags(sceneID, tagJoins, tx) + + return err == nil, err +} + func (qb *JoinsQueryBuilder) DestroyScenesTags(sceneID int, tx *sqlx.Tx) error { ensureTx(tx) diff --git a/pkg/models/querybuilder_scene.go b/pkg/models/querybuilder_scene.go index 34ea9f5e5b4..621101b4399 100644 --- a/pkg/models/querybuilder_scene.go +++ b/pkg/models/querybuilder_scene.go @@ -291,6 +291,30 @@ func getMultiCriterionClause(table string, joinTable string, joinTableField stri return whereClause, havingClause } +func (qb *SceneQueryBuilder) QueryAllByPathRegex(regex string) ([]*Scene, error) { + var args []interface{} + body := selectDistinctIDs("scenes") + " WHERE scenes.path regexp '(?i)" + regex + "'" + + idsResult, err := runIdsQuery(body, args) + + if err != nil { + return nil, err + } + + var scenes []*Scene + for _, id := range idsResult { + scene, err := qb.Find(id) + + if err != nil { + return nil, err + } + + scenes = append(scenes, scene) + } + + return scenes, nil +} + func (qb *SceneQueryBuilder) QueryByPathRegex(findFilter *FindFilterType) ([]*Scene, int) { if findFilter == nil { findFilter = &FindFilterType{} diff --git a/ui/v2/src/components/Settings/SettingsTasksPanel/SettingsTasksPanel.tsx b/ui/v2/src/components/Settings/SettingsTasksPanel/SettingsTasksPanel.tsx index 5ef1c2fe78d..c3334331374 100644 --- a/ui/v2/src/components/Settings/SettingsTasksPanel/SettingsTasksPanel.tsx +++ b/ui/v2/src/components/Settings/SettingsTasksPanel/SettingsTasksPanel.tsx @@ -25,6 +25,10 @@ export const SettingsTasksPanel: FunctionComponent = (props: IProps) => const [status, setStatus] = useState(""); const [progress, setProgress] = useState(undefined); + const [autoTagPerformers, setAutoTagPerformers] = useState(true); + const [autoTagStudios, setAutoTagStudios] = useState(true); + const [autoTagTags, setAutoTagTags] = useState(true); + const jobStatus = StashService.useJobStatus(); const metadataUpdate = StashService.useMetadataUpdate(); @@ -42,6 +46,8 @@ export const SettingsTasksPanel: FunctionComponent = (props: IProps) => return "Exporting to JSON"; case "Import": return "Importing from JSON"; + case "Auto Tag": + return "Auto tagging scenes"; } return "Idle"; @@ -130,6 +136,25 @@ export const SettingsTasksPanel: FunctionComponent = (props: IProps) => } } + function getAutoTagInput() { + var wildcard = ["*"]; + return { + performers: autoTagPerformers ? wildcard : [], + studios: autoTagStudios ? wildcard : [], + tags: autoTagTags ? wildcard : [] + } + } + + async function onAutoTag() { + try { + await StashService.queryMetadataAutoTag(getAutoTagInput()); + ToastUtils.success("Started auto tagging"); + jobStatus.refetch(); + } catch (e) { + ErrorUtils.handle(e); + } + } + function maybeRenderStop() { if (!status || status === "Idle") { return undefined; @@ -180,11 +205,38 @@ export const SettingsTasksPanel: FunctionComponent = (props: IProps) => /> ) + } + } + function renderScenesButton() { if (props.isEditing) { return; } let linkSrc: string = "#"; @@ -136,6 +146,7 @@ export const DetailsEditNavbar: FunctionComponent = (props: IProps) => { {renderImageInput()} {renderSaveButton()} + {renderAutoTagButton()} {renderScenesButton()} {renderDeleteButton()} diff --git a/ui/v2/src/components/Studios/StudioDetails/Studio.tsx b/ui/v2/src/components/Studios/StudioDetails/Studio.tsx index cb7b3f0d3e8..7ae724c9969 100644 --- a/ui/v2/src/components/Studios/StudioDetails/Studio.tsx +++ b/ui/v2/src/components/Studios/StudioDetails/Studio.tsx @@ -15,6 +15,7 @@ import { IBaseProps } from "../../../models"; import { ErrorUtils } from "../../../utils/errors"; import { TableUtils } from "../../../utils/table"; import { DetailsEditNavbar } from "../../Shared/DetailsEditNavbar"; +import { ToastUtils } from "../../../utils/toasts"; interface IProps extends IBaseProps {} @@ -96,6 +97,18 @@ export const Studio: FunctionComponent = (props: IProps) => { setIsLoading(false); } + async function onAutoTag() { + if (!studio || !studio.id) { + return; + } + try { + await StashService.queryMetadataAutoTag({ studios: [studio.id]}); + ToastUtils.success("Started auto tagging"); + } catch (e) { + ErrorUtils.handle(e); + } + } + async function onDelete() { setIsLoading(true); try { @@ -135,6 +148,7 @@ export const Studio: FunctionComponent = (props: IProps) => { onToggleEdit={() => { setIsEditing(!isEditing); updateStudioEditState(studio); }} onSave={onSave} onDelete={onDelete} + onAutoTag={onAutoTag} onImageChange={onImageChange} />

diff --git a/ui/v2/src/components/Tags/TagList.tsx b/ui/v2/src/components/Tags/TagList.tsx index aae49b025c5..a7ddf1cd93e 100644 --- a/ui/v2/src/components/Tags/TagList.tsx +++ b/ui/v2/src/components/Tags/TagList.tsx @@ -77,6 +77,18 @@ export const TagList: FunctionComponent = (props: IProps) => { } } + async function onAutoTag(tag : GQL.TagDataFragment) { + if (!tag) { + return; + } + try { + await StashService.queryMetadataAutoTag({ tags: [tag.id]}); + ToastUtils.success("Started auto tagging"); + } catch (e) { + ErrorUtils.handle(e); + } + } + async function onDelete() { try { await deleteTag(); @@ -115,6 +127,7 @@ export const TagList: FunctionComponent = (props: IProps) => {
setEditingTag(tag)}>{tag.name}
+ Scenes: {tag.scene_count} Markers: {tag.scene_marker_count} diff --git a/ui/v2/src/components/performers/PerformerDetails/Performer.tsx b/ui/v2/src/components/performers/PerformerDetails/Performer.tsx index 26f035c6cfd..51cf2f79cf2 100644 --- a/ui/v2/src/components/performers/PerformerDetails/Performer.tsx +++ b/ui/v2/src/components/performers/PerformerDetails/Performer.tsx @@ -15,6 +15,7 @@ import { ErrorUtils } from "../../../utils/errors"; import { TableUtils } from "../../../utils/table"; import { ScrapePerformerSuggest } from "../../select/ScrapePerformerSuggest"; import { DetailsEditNavbar } from "../../Shared/DetailsEditNavbar"; +import { ToastUtils } from "../../../utils/toasts"; interface IPerformerProps extends IBaseProps {} @@ -171,6 +172,18 @@ export const Performer: FunctionComponent = (props: IPerformerP props.history.push(`/performers`); } + async function onAutoTag() { + if (!performer || !performer.id) { + return; + } + try { + await StashService.queryMetadataAutoTag({ performers: [performer.id]}); + ToastUtils.success("Started auto tagging"); + } catch (e) { + ErrorUtils.handle(e); + } + } + function onImageChange(event: React.FormEvent) { const file: File = (event.target as any).files[0]; const reader: FileReader = new FileReader(); @@ -315,6 +328,7 @@ export const Performer: FunctionComponent = (props: IPerformerP onImageChange={onImageChange} scrapers={queryableScrapers} onDisplayScraperDialog={onDisplayFreeOnesDialog} + onAutoTag={onAutoTag} />

({ + query: GQL.MetadataAutoTagDocument, + variables: { input }, + fetchPolicy: "network-only", + }); + } + public static queryMetadataGenerate(input: GQL.GenerateMetadataInput) { return StashService.client.query({ query: GQL.MetadataGenerateDocument,