@ -4,13 +4,17 @@
package unittest
import (
"fmt"
"math"
"os"
"strings"
"code.gitea.io/gitea/models/db"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"xorm.io/builder"
"xorm.io/xorm"
)
// Code in this file is mainly used by unittest.CheckConsistencyFor, which is not in the unit test for various reasons.
@ -51,22 +55,23 @@ func whereOrderConditions(e db.Engine, conditions []any) db.Engine {
return e . OrderBy ( orderBy )
}
// LoadBeanIfExists loads beans from fixture database if exist
func LoadBeanIfExists ( bean any , conditions ... any ) ( bool , error ) {
func getBeanIfExists ( bean any , conditions ... any ) ( bool , error ) {
e := db . GetEngine ( db . DefaultContext )
return whereOrderConditions ( e , conditions ) . Get ( bean )
}
// BeanExists for testing, check if a bean exists
func BeanExists ( t assert . TestingT , bean any , conditions ... any ) bool {
exists , err := LoadBeanIfExists ( bean , conditions ... )
assert . NoError ( t , err )
return exists
func GetBean [ T any ] ( t require . TestingT , bean T , conditions ... any ) ( ret T ) {
exists , err := getBeanIfExists ( bean , conditions ... )
require . NoError ( t , err )
if exists {
return bean
}
return ret
}
// AssertExistsAndLoadBean assert that a bean exists and load it from the test database
func AssertExistsAndLoadBean [ T any ] ( t require . TestingT , bean T , conditions ... any ) T {
exists , err := Load BeanIfExists( bean , conditions ... )
exists , err := get BeanIfExists( bean , conditions ... )
require . NoError ( t , err )
require . True ( t , exists ,
"Expected to find %+v (of type %T, with conditions %+v), but did not" ,
@ -112,25 +117,11 @@ func GetCount(t assert.TestingT, bean any, conditions ...any) int {
// AssertNotExistsBean assert that a bean does not exist in the test database
func AssertNotExistsBean ( t assert . TestingT , bean any , conditions ... any ) {
exists , err := Load BeanIfExists( bean , conditions ... )
exists , err := get BeanIfExists( bean , conditions ... )
assert . NoError ( t , err )
assert . False ( t , exists )
}
// AssertExistsIf asserts that a bean exists or does not exist, depending on
// what is expected.
func AssertExistsIf ( t assert . TestingT , expected bool , bean any , conditions ... any ) {
exists , err := LoadBeanIfExists ( bean , conditions ... )
assert . NoError ( t , err )
assert . Equal ( t , expected , exists )
}
// AssertSuccessfulInsert assert that beans is successfully inserted
func AssertSuccessfulInsert ( t assert . TestingT , beans ... any ) {
err := db . Insert ( db . DefaultContext , beans ... )
assert . NoError ( t , err )
}
// AssertCount assert the count of a bean
func AssertCount ( t assert . TestingT , bean , expected any ) bool {
return assert . EqualValues ( t , expected , GetCount ( t , bean ) )
@ -155,3 +146,39 @@ func AssertCountByCond(t assert.TestingT, tableName string, cond builder.Cond, e
return assert . EqualValues ( t , expected , GetCountByCond ( t , tableName , cond ) ,
"Failed consistency test, the counted bean (of table %s) was %+v" , tableName , cond )
}
// DumpQueryResult dumps the result of a query for debugging purpose
func DumpQueryResult ( t require . TestingT , sqlOrBean any , sqlArgs ... any ) {
x := db . GetEngine ( db . DefaultContext ) . ( * xorm . Engine )
goDB := x . DB ( ) . DB
sql , ok := sqlOrBean . ( string )
if ! ok {
sql = fmt . Sprintf ( "SELECT * FROM %s" , db . TableName ( sqlOrBean ) )
} else if ! strings . Contains ( sql , " " ) {
sql = fmt . Sprintf ( "SELECT * FROM %s" , sql )
}
rows , err := goDB . Query ( sql , sqlArgs ... )
require . NoError ( t , err )
defer rows . Close ( )
columns , err := rows . Columns ( )
require . NoError ( t , err )
_ , _ = fmt . Fprintf ( os . Stdout , "====== DumpQueryResult: %s ======\n" , sql )
idx := 0
for rows . Next ( ) {
row := make ( [ ] any , len ( columns ) )
rowPointers := make ( [ ] any , len ( columns ) )
for i := range row {
rowPointers [ i ] = & row [ i ]
}
require . NoError ( t , rows . Scan ( rowPointers ... ) )
_ , _ = fmt . Fprintf ( os . Stdout , "- # row[%d]\n" , idx )
for i , col := range columns {
_ , _ = fmt . Fprintf ( os . Stdout , " %s: %v\n" , col , row [ i ] )
}
idx ++
}
if idx == 0 {
_ , _ = fmt . Fprintf ( os . Stdout , "(no result, columns: %s)\n" , strings . Join ( columns , ", " ) )
}
}