From 91013ab5239c8a638fa0082a745ef212b1a251d7 Mon Sep 17 00:00:00 2001 From: lj-wsdj <1134294381@qq.com> Date: Tue, 6 Aug 2024 14:06:23 +0800 Subject: [PATCH] fix --- db/add_sql.go | 59 +++++++++++++++++++++++++++++++++++++++++++++++++-- db/db.go | 3 ++- main_test.go | 9 ++++++-- 3 files changed, 66 insertions(+), 5 deletions(-) diff --git a/db/add_sql.go b/db/add_sql.go index e4d303a..6f3218b 100644 --- a/db/add_sql.go +++ b/db/add_sql.go @@ -1,8 +1,10 @@ package db import ( + "database/sql" "fmt" "reflect" + "regexp" "strings" "git.botann.com/lijun/sql-builder/util" @@ -17,7 +19,11 @@ func (s *SqlBuilder[T]) Conditions(cond string, args ...interface{}) { if strings.HasPrefix(cond, " and") { cond = strings.Join(strings.Split(cond, " and")[1:], " and ") } - cond, args = dealInCondition(cond, args) + if s.Driver == util.DriverOracle { + cond, args = dealOralceInCondition(cond, args) + } else { + cond, args = dealInCondition(cond, args) + } s.Opts.conditions = append(s.Opts.conditions, cond) s.Opts.args = append(s.Opts.args, args...) s.buildCondition(util.ConditionPlaceholder, cond) @@ -27,7 +33,11 @@ func (s *SqlBuilder[T]) InsertConditions(symbol string, cond string, args ...int if strings.HasPrefix(cond, " and") { cond = strings.Join(strings.Split(cond, " and")[1:], " and ") } - cond, args = dealInCondition(cond, args) + if s.Driver == util.DriverOracle { + cond, args = dealOralceInCondition(cond, args) + } else { + cond, args = dealInCondition(cond, args) + } s.Opts.conditions = append(s.Opts.conditions, cond) s.Opts.args = append(s.Opts.args, args...) s.buildCondition(symbol, cond) @@ -95,3 +105,48 @@ func dealInCondition(cond string, args []interface{}) (string, []interface{}) { } return cond, newArgs } + +func dealOralceInCondition(cond string, args []interface{}) (string, []interface{}) { + if ok := strings.Contains(cond, "(:in)"); !ok { + return dealOracleCondition(cond, args) + } + var newArgs []interface{} + for i := 0; i < len(args); i++ { + //判断args参数是否为slice + if reflect.TypeOf(args[i]).Kind() == reflect.Slice { + var str []string + for j := 0; j < reflect.ValueOf(args[i]).Len(); j++ { + //如果是slice,则将cond里的?替代为args的长度个? + str = append(str, fmt.Sprintf(":in%d", j)) + //将args[i]里的元素添加到new_args里 + newArgs = append(newArgs, sql.Named(fmt.Sprintf("in%d", j), reflect.ValueOf(args[i]).Index(j).Interface())) + } + cond = strings.Replace(cond, "(:in)", fmt.Sprintf("(%s)", strings.Join(str, ",")), -1) + } else { + newArgs = append(newArgs, args[i]) + } + } + return cond, newArgs +} + +func dealOracleCondition(cond string, args []interface{}) (string, []interface{}) { + re := regexp.MustCompile(`:(\w+)`) + matches := re.FindAllStringSubmatch(cond, -1) + var result []string + for _, match := range matches { + if len(match) > 1 { + result = append(result, match[1]) + } + } + newArgs := make([]interface{}, 0) + for i := 0; i < len(args); i++ { + if reflect.ValueOf(args[i]).Kind() == reflect.Slice { + for j := 0; j < reflect.ValueOf(args[i]).Len(); j++ { + newArgs = append(newArgs, sql.Named(fmt.Sprintf("%s", result[j]), reflect.ValueOf(args[i]).Index(j).Interface())) + } + } else { + newArgs = append(newArgs, sql.Named(fmt.Sprintf("%s", result[i]), args[i])) + } + } + return cond, newArgs +} diff --git a/db/db.go b/db/db.go index a000841..983dce5 100644 --- a/db/db.go +++ b/db/db.go @@ -12,6 +12,7 @@ import ( type SqlConnector[T tydb.QOneRowRst | tytaos.QOneRowRst | tyoracle.QOneRowRst | int] interface { QueryEx(sql string, data interface{}, args ...interface{}) error QueryRow(sql string, data interface{}, args ...interface{}) (error, T) + DDL(sql string) error } type SqlBuilder[T tydb.QOneRowRst | tytaos.QOneRowRst | tyoracle.QOneRowRst | int] struct { @@ -38,7 +39,7 @@ func (s *SqlBuilder[T]) Clear() { func (s *SqlBuilder[T]) SetSchema(schema string) error { if s.Driver == util.DriverOracle { - err := tyoracle.DB().DDL(fmt.Sprintf("ALTER SESSION SET CURRENT_SCHEMA = %s", schema)) + err := s.Connector.DDL(fmt.Sprintf("ALTER SESSION SET CURRENT_SCHEMA = %s", schema)) if err != nil { return err } diff --git a/main_test.go b/main_test.go index 31a2c58..3b2192d 100644 --- a/main_test.go +++ b/main_test.go @@ -60,9 +60,14 @@ func TestOracle(t *testing.T) { var data []Employye sqlb := OracleSqlBuilder() err = sqlb.SetSchema("HR") - fmt.Println(err) - sqls := `select employee_id,first_name,last_name,email from employees` + if err != nil { + fmt.Println(err) + } + var ids = []int{100, 101, 102} + sqls := `select employee_id,first_name,last_name,email from employees where 1=1 @c1 @c2` sqlb.Sql = sqls + sqlb.InsertConditions("@c1", "employee_id in (:in)", ids) + sqlb.InsertConditions("@c2", "first_name = :first_name and last_name = :last_name", "Steven", "King") page := Page{PageIndex: 1, PageSize: 10} err = sqlb.PaginateBySql(&data, &page) page.Data = &data